summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-01-25 21:46:29 +0100
committerTom Smeding <tom@tomsmeding.com>2024-01-25 21:46:29 +0100
commit2a53042c1ce8b593a6178696c03ac77c6b76b395 (patch)
treecb7a8a82ac254980d3fbb1911d4a7d891647a561
parent39b899b4951be5b78058d5c0e35977b065a63951 (diff)
Finish rewrite
-rw-r--r--src/AST/Pretty.hs1
-rw-r--r--src/AST/Weaken.hs10
-rw-r--r--src/CHAD.hs149
-rw-r--r--src/Example.hs1
-rw-r--r--src/Lemmas.hs3
5 files changed, 82 insertions, 82 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 6bc75ed..1ffa980 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -15,6 +15,7 @@ import Data.Functor.Const
import AST
import AST.Count
+import Data
data Val f env where
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index c7668e7..07a90dc 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -40,10 +40,14 @@ instance KnownListSpine list => KnownListSpine (t : list) where knownListSpine =
wSinks' :: forall list env. KnownListSpine list => env :> Append list env
wSinks' = wSinks (knownListSpine :: SList (Const ()) list)
-wSinks :: SList f list' -> env :> Append list' env
+wSinks :: forall env bs f. SList f bs -> env :> Append bs env
wSinks SNil = WId
wSinks (SCons _ spine) = WSink .> wSinks spine
+wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2
+wCopies SNil w = w
+wCopies (SCons _ spine) w = WCopy (wCopies spine w)
+
wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env
wStack WId = WId
wStack WSink = WSink
@@ -52,6 +56,6 @@ wStack (WPop w) = WPop (wStack @env w)
wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2)
wStack (WClosed s) = wSinks s
-wRaiseAbove :: SList f env1 -> SList (Const ()) env -> env1 :> Append env1 env
-wRaiseAbove SNil env = WClosed env
+wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
+wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env)
wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 642f58f..d232fee 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -13,12 +13,22 @@
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE UndecidableInstances #-}
-module CHAD where
+
+-- I want to bring various type variables in scope using type annotations in
+-- patterns, but I don't want to have to mention all the other type parameters
+-- of the types in question as well then. Partial type signatures (with '_') are
+-- useful here.
+{-# OPTIONS -Wno-partial-type-signatures #-}
+module CHAD (
+ drev,
+ freezeRet,
+ Storage(..),
+ Descr(..),
+) where
import Data.Bifunctor (first, second)
import Data.Functor.Const
import Data.Kind (Type)
-import Data.Proxy
import Data.Some
import GHC.TypeLits (Symbol)
@@ -65,26 +75,25 @@ bconcat b1 (BPush (b2 :: Bindings f (Append binds1 env) binds2C) (t, x))
-- -> (forall env12. Bindings f env env12 -> r) -> r
-- bconcat' wf b1 b2 k = weakenBindings wf (sinkWithBindings b1) b2 $ \b2' _ -> k (bconcat b1 b2')
-type family Snoc l x where
- Snoc '[] x = '[x]
- Snoc (y : ys) x = y : Snoc ys x
+-- type family Snoc l x where
+-- Snoc '[] x = '[x]
+-- Snoc (y : ys) x = y : Snoc ys x
-bsnoc :: (forall env1 env2 t'. env1 :> env2 -> f env1 t' -> f env2 t')
- -> STy t -> f env t -> Bindings f env binds -> (Bindings f env (Snoc binds t), Append binds env :> Append (Snoc binds t) env)
-bsnoc _ t x BTop = (BPush BTop (t, x), WSink)
-bsnoc wf t x (BPush b (t', y)) =
- let (b', w) = bsnoc wf t x b
- in (BPush b' (t', wf w y), WCopy w)
+-- bsnoc :: (forall env1 env2 t'. env1 :> env2 -> f env1 t' -> f env2 t')
+-- -> STy t -> f env t -> Bindings f env binds -> (Bindings f env (Snoc binds t), Append binds env :> Append (Snoc binds t) env)
+-- bsnoc _ t x BTop = (BPush BTop (t, x), WSink)
+-- bsnoc wf t x (BPush b (t', y)) =
+-- let (b', w) = bsnoc wf t x b
+-- in (BPush b' (t', wf w y), WCopy w)
type family Tape binds where
Tape '[] = TNil
Tape (t : ts) = TPair t (Tape ts)
--- TODO: The problem is that in the 3rd field of TupBinds, we should reconstruct a stack of let bindings from the tape, but we can't directly without having quadratic code size (due to nested projections). Instead we should produce _two_ Bindings there: one to an existential intermediate 'tempbinds', and another that picks up from there and creates 'binds'.
-data TupBinds f env binds =
- TupBinds (SList STy binds)
- (forall env2. Append binds env :> env2 -> Ex env2 (Tape binds))
- (forall env2. Idx env2 (Tape binds) -> Bindings f env2 binds)
+-- data TupBinds f env binds =
+-- TupBinds (SList STy binds)
+-- (forall env2. Append binds env :> env2 -> Ex env2 (Tape binds))
+-- (forall env2. Idx env2 (Tape binds) -> Bindings f env2 binds)
bindingsBinds :: Bindings f env binds -> SList STy binds
bindingsBinds BTop = SNil
@@ -115,7 +124,9 @@ bindingsCollect (BPush binds (t, _)) w =
-- reconFromTape (SCons t ts) (BPush unfbinds (_, e)) = _
-- reconFromTape SCons{} BTop = error "unreachable"
--- TODO: This function produces quadratic code, but it must be linear. Need to fix this!
+-------------------------------------------------------------------------------------------
+-- TODO: This function produces quadratic code, but it must be linear. Need to fix this! --
+-------------------------------------------------------------------------------------------
reconstructBindings :: SList STy binds -> Ex env (Tape binds) -> Bindings Ex env binds
reconstructBindings SNil _ = BTop
reconstructBindings (SCons t ts) e = BPush (reconstructBindings ts (ESnd ext e)) (t, weakenExpr (sinkOver ts) (EFst ext e))
@@ -258,21 +269,6 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext
zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
-onehotTup :: SList STy env0 -> Idx env0 t -> Ex env (D2 t) -> Ex env (Tup (D2E env0))
-onehotTup (SCons _ env) IZ d = EPair ext (zeroTup env) d
-onehotTup (SCons t env) (IS i) d = EPair ext (onehotTup env i d) (zero t)
-onehotTup SNil i _ = case i of {}
-
-plusTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -> Ex env (Tup (D2E env0)) -> Ex env (Tup (D2E env0))
-plusTup SNil _ _ = ENil ext
-plusTup env0@(SCons t env) a b =
- ELet ext a $
- ELet ext (weakenExpr WSink b) $
- EPair ext (plusTup env (EFst ext (EVar ext (tTup (d2e env0)) (IS IZ)))
- (EFst ext (EVar ext (tTup (d2e env0)) IZ)))
- (plus t (ESnd ext (EVar ext (tTup (d2e env0)) (IS IZ)))
- (ESnd ext (EVar ext (tTup (d2e env0)) IZ)))
-
data Subenv env env' where
SETop :: Subenv '[] '[]
SEYes :: Subenv env env' -> Subenv (t : env) (t : env')
@@ -397,7 +393,7 @@ weakenRets w (Rets binds list) =
rebaseRetPair :: forall env b1 b2 env0 sto t f. SList f b1 -> SList f b2 -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t
rebaseRetPair b1 b2 (RetPair p sub d)
| Refl <- lemAppendAssoc @b2 @b1 @env =
- RetPair p sub (weakenExpr (WCopy (wRaiseAbove b2 (slistMap (\_ -> Const ()) b1))) d)
+ RetPair p sub (weakenExpr (WCopy (wRaiseAbove b2 b1)) d)
retConcat :: forall env0 sto list. SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
retConcat SNil = Rets BTop SNil
@@ -525,7 +521,6 @@ d2e :: SList STy env -> SList STy (D2E env)
d2e SNil = SNil
d2e (SCons t ts) = SCons (d2 t) (d2e ts)
-{-
drev :: forall env sto t.
Descr env sto
-> (forall env' sto' t'. Descr env' sto' -> STy t' -> Some Storage)
@@ -546,10 +541,12 @@ drev des policy = \case
(EMReturn d2acc (EPair ext (ENil ext) (EVar ext (d2 t) IZ)))
ELet _ rhs body
- | Ret rhs0 rhs1 subRHS rhs2 <- drev des policy rhs
+ | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des policy rhs
, Some storage <- policy des (typeOf rhs)
- , Ret body0 body1 subBody body2 <- drev (des `DPush` (typeOf rhs, storage)) policy body ->
- weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 $ \body0' wbody0' ->
+ , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, storage)) policy body
+ , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
+ , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendNil @body_shbinds ->
unscope des (typeOf rhs) storage subBody body2 $ \subBody' body2' ->
subenvPlus (select SMerge des) subRHS subBody' $ \subBoth _ _ plus_RHS_Body ->
let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in
@@ -557,10 +554,10 @@ drev des policy = \case
(weakenExpr wbody0' body1)
subBoth
(EMBind
- (weakenExpr (WCopy wbody0') body2')
+ (weakenExpr (WCopy (wRaiseAbove (bindingsBinds body0) (SCons (typeOf rhs1) (bindingsBinds rhs0)))) body2')
(EMBind
(ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
- weakenExpr (WCopy (wSinks' @[_,_] .> WPop (sinkWithBindings body0'))) rhs2)
+ weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2)
(EMReturn d2acc (plus_RHS_Body
(EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
(EFst ext (EVar ext bodyResType (IS IZ)))))))
@@ -626,61 +623,56 @@ drev des policy = \case
(EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inr<-dinl")
(weakenExpr (WCopy (wSinks' @[_,_])) e2)))
- ECase _ e a b
+ ECase _ e (a :: Ex _ t) b
| STEither t1 t2 <- typeOf e
- , Ret e0 e1 subE e2 <- drev des policy e
+ , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des policy e
, Some storageA <- policy des t1
, Some storageB <- policy des t2
- , Ret a0 a1 subA a2 <- drev (des `DPush` (t1, storageA)) policy a
- , Ret b0 b1 subB b2 <- drev (des `DPush` (t2, storageB)) policy b
- , TupBinds tapeA collectA reconA <- tupBinds a0
- , TupBinds tapeB collectB reconB <- tupBinds b0
- , let tPrimal = STPair (d1 (typeOf a)) (STEither tapeA tapeB) ->
- weakenBindings weakenExpr (WCopy (WSink .> sinkWithBindings e0)) a0 $ \a0' wa0' ->
- weakenBindings weakenExpr (WCopy (WSink .> sinkWithBindings e0)) b0 $ \b0' wb0' ->
+ , Ret (a0 :: Bindings _ _ rhs_a_binds) a1 subA a2 <- drev (des `DPush` (t1, storageA)) policy a
+ , Ret (b0 :: Bindings _ _ rhs_b_binds) b1 subB b2 <- drev (des `DPush` (t2, storageB)) policy b
+ , let tapeA = tapeTy (bindingsBinds a0)
+ , let tapeB = tapeTy (bindingsBinds b0)
+ , let collectA = bindingsCollect a0
+ , let collectB = bindingsCollect b0
+ , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
+ , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
+ , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 ->
unscope des t1 storageA subA a2 $ \subA' a2' ->
unscope des t2 storageB subB b2 $ \subB' b2' ->
subenvPlus (select SMerge des) subA' subB' $ \subAB sAB_A sAB_B _ ->
subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E ->
let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in
Ret (e0 `BPush`
- (d1 (typeOf e), e1) `BPush`
(tPrimal,
- ECase ext (EVar ext (d1 (typeOf e)) IZ)
+ ECase ext e1
(letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'))))
(letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'))))))
(EFst ext (EVar ext tPrimal IZ))
subOut
(EMBind
- (ECase ext (EVar ext (STEither (d1 t1) (d1 t2)) (IS (IS IZ)))
- (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ))))
- (case reconA (WSink .> WCopy (wSinks' @[_,_,_] .> sinkWithBindings e0)) IZ of
- TupBindsReconstruct rebinds wrebinds ->
- letBinds rebinds $
- ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $
- EMBind (weakenExpr (WCopy wrebinds) a2')
- (EMReturn d2acc
- (EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))
- (EInl ext (d2 t2)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))))))
- (EError (STEVM d2acc tCaseRet) "dcase l/rtape"))
- (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ))))
- (EError (STEVM d2acc tCaseRet) "dcase r/ltape")
- (case reconB (WSink .> WCopy (wSinks' @[_,_,_] .> sinkWithBindings e0)) IZ of
- TupBindsReconstruct rebinds wrebinds ->
- letBinds rebinds $
- ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $
- EMBind (weakenExpr (WCopy wrebinds) b2')
- (EMReturn d2acc
- (EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))
- (EInr ext (d2 t1)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))))))))
+ (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
+ (let rebinds = reconstructBindings (bindingsBinds a0) (EVar ext tapeA IZ)
+ in letBinds rebinds $
+ ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : e_binds) (bindingsBinds a0) @> IS IZ)) $
+ EMBind (weakenExpr (WCopy (wRaiseAbove (bindingsBinds a0) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0))) a2')
+ (EMReturn d2acc
+ (EPair ext
+ (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
+ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))
+ (EInl ext (d2 t2)
+ (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))))))
+ (let rebinds = reconstructBindings (bindingsBinds b0) (EVar ext tapeB IZ)
+ in letBinds rebinds $
+ ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : e_binds) (bindingsBinds b0) @> IS IZ)) $
+ EMBind (weakenExpr (WCopy (wRaiseAbove (bindingsBinds b0) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0))) b2')
+ (EMReturn d2acc
+ (EPair ext
+ (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
+ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))
+ (EInr ext (d2 t1)
+ (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ)))))))
(EMBind (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $
- weakenExpr (WCopy (wSinks' @[_,_,_,_])) e2) $
+ weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
EMReturn d2acc $
plus_AB_E
(EFst ext (EVar ext tCaseRet (IS IZ)))
@@ -713,4 +705,3 @@ drev des policy = \case
where
d2acc = d2e (select SAccum des)
--}
diff --git a/src/Example.hs b/src/Example.hs
index 643b82f..30031c0 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -6,6 +6,7 @@ import Data.Some
import AST
import AST.Pretty
import CHAD
+import Data
import Simplify
diff --git a/src/Lemmas.hs b/src/Lemmas.hs
index 7dbf680..cb62155 100644
--- a/src/Lemmas.hs
+++ b/src/Lemmas.hs
@@ -11,5 +11,8 @@ import Unsafe.Coerce (unsafeCoerce)
import AST.Weaken
+lemAppendNil :: Append a '[] :~: a
+lemAppendNil = unsafeCoerce Refl
+
lemAppendAssoc :: Append a (Append b c) :~: Append (Append a b) c
lemAppendAssoc = unsafeCoerce Refl