summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2024-09-02 17:49:54 +0200
committerTom Smeding <t.j.smeding@uu.nl>2024-09-02 17:50:12 +0200
commit7d44dcc2ca2c5c16e1ab4737ef6b2877214767ed (patch)
tree42e8b9292403f9ce3a6f04a15ebd62a766880339
parent1f7ed2ee02222108684cfde8078e7a182f734a61 (diff)
WIP autoWeak
-rw-r--r--chad-fast.cabal2
-rw-r--r--src/AST.hs5
-rw-r--r--src/AST/Weaken.hs139
-rw-r--r--src/CHAD.hs82
-rw-r--r--src/Data.hs6
-rw-r--r--src/Lemmas.hs7
6 files changed, 218 insertions, 23 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 2e7ee22..ca3a2aa 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -23,7 +23,7 @@ library
Simplify
other-modules:
build-depends:
- base >= 4.14 && < 4.19,
+ base >= 4.19 && < 4.21,
containers,
template-haskell,
transformers,
diff --git a/src/AST.hs b/src/AST.hs
index c191651..d9f5ef7 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -261,11 +261,6 @@ idx2int :: Idx env t -> Int
idx2int IZ = 0
idx2int (IS n) = 1 + idx2int n
-splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t)
-splitIdx SNil i = Right i
-splitIdx (SCons _ _) IZ = Left IZ
-splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i)
-
class KnownScalTy t where knownScalTy :: SScalTy t
instance KnownScalTy TI32 where knownScalTy = STI32
instance KnownScalTy TI64 where knownScalTy = STI64
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index 78577ee..6c66b07 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -1,23 +1,34 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE PartialTypeSignatures #-}
+{-# OPTIONS -Wno-partial-type-signatures #-}
+
-- The reason why this is a separate module with "little" in it:
{-# LANGUAGE AllowAmbiguousTypes #-}
-module AST.Weaken where
+module AST.Weaken (module AST.Weaken, Append) where
+import Data.Bifunctor (first)
import Data.Functor.Const
import Data.Kind (Type)
+import GHC.TypeLits
+import Unsafe.Coerce
import Data
+import Lemmas
type Idx :: [k] -> k -> Type
@@ -26,6 +37,11 @@ data Idx env t where
IS :: Idx env t -> Idx (a : env) t
deriving instance Show (Idx env t)
+splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t)
+splitIdx SNil i = Right i
+splitIdx (SCons _ _) IZ = Left IZ
+splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i)
+
data env :> env' where
WId :: env :> env
WSink :: forall t env. env :> (t : env)
@@ -34,28 +50,38 @@ data env :> env' where
WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3
WClosed :: SList (Const ()) env -> '[] :> env
WIdx :: Idx env t -> (t : env) :> env
+ WPick :: forall t pre env env'. SList (Const ()) pre -> env :> env'
+ -> Append pre (t : env) :> t : Append pre env'
+ WSwap :: SList (Const ()) as -> SList (Const ()) bs -> SList (Const ()) env
+ -> Append as (Append bs env) :> Append bs (Append as env)
deriving instance Show (env :> env')
+infix 4 :>
infixr 2 @>
(@>) :: env :> env' -> Idx env t -> Idx env' t
WId @> i = i
WSink @> i = IS i
WCopy _ @> IZ = IZ
-WCopy w @> (IS i) = IS (w @> i)
+WCopy w @> IS i = IS (w @> i)
WPop w @> i = w @> IS i
WThen w1 w2 @> i = w2 @> w1 @> i
WClosed _ @> i = case i of {}
WIdx j @> IZ = j
WIdx _ @> IS i = i
+WPick SNil w @> i = WCopy w @> i
+WPick (_ `SCons` _) _ @> IZ = IS IZ
+WPick @t (_ `SCons` pre) w @> IS i = WCopy WSink .> WPick @t pre w @> i
+WSwap (as :: SList _ as) (bs :: SList _ bs) (env :: SList _ env) @> i =
+ case splitIdx @(Append bs env) as i of
+ Left i' -> wSinks bs .> wRaiseAbove as env @> i'
+ Right j -> case splitIdx @env bs j of
+ Left j' -> wRaiseAbove bs (sappend as env) @> j'
+ Right k -> wSinks bs .> wSinks as @> k
infixr 3 .>
(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3
(.>) = flip WThen
-type family Append a b where
- Append '[] l = l
- Append (x : xs) l = x : Append xs l
-
class KnownListSpine list where knownListSpine :: SList (Const ()) list
instance KnownListSpine '[] where knownListSpine = SNil
instance KnownListSpine list => KnownListSpine (t : list) where knownListSpine = SCons (Const ()) knownListSpine
@@ -83,7 +109,108 @@ wStack (WIdx i) = WIdx (goIdx i)
goIdx :: Idx b t -> Idx (Append b env) t
goIdx IZ = IZ
goIdx (IS i') = IS (goIdx i')
+wStack (WPick @t @_ @env1 @env2 (pre :: SList (Const ()) pre) w)
+ | Refl <- lemAppendAssoc @pre @env2 @env
+ , Refl <- lemAppendAssoc @pre @(t : env1) @env
+ = WPick @t @_ pre (wStack @env w)
+wStack WSwap{} = error "OOPS"
wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env)
wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
+
+
+type family Lookup name list where
+ Lookup name ('(name, x) : _) = x
+ Lookup name (_ : list) = Lookup name list
+
+data Layout (segments :: [(Symbol, [t])]) (env :: [t]) where
+ LSeg :: forall name segments. KnownSymbol name => Layout segments (Lookup name segments)
+ (:++:) :: Layout segments env1 -> Layout segments env2 -> Layout segments (Append env1 env2)
+
+data SSegments (segments :: [(Symbol, [t])]) where
+ SSegNil :: SSegments '[]
+ SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)
+
+class KnownSegments (segments :: [(Symbol, [t])]) where
+ knownSegments :: SSegments segments
+
+instance KnownSegments '[] where
+ knownSegments = SSegNil
+
+instance (KnownSymbol name, KnownListSpine ts, KnownSegments list)
+ => KnownSegments ('(name, ts) : list) where
+ knownSegments = SSegCons symbolSing knownListSpine knownSegments
+
+segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments)
+segmentLookup = \segs name -> case go segs name of
+ Just ts -> ts
+ Nothing -> error $ "Segment not found: " ++ fromSSymbol name
+ where
+ go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs'))
+ go SSegNil _ = Nothing
+ go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol =
+ case sameSymbol n name of
+ Just Refl ->
+ case go sseg name of
+ Nothing -> Just ts
+ Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name
+ Nothing ->
+ case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of
+ Refl -> go sseg name
+
+data LinLayout (segments :: [(Symbol, [t])]) (env :: [t]) where
+ LinEnd :: LinLayout segments '[]
+ LinApp :: SSymbol name -> LinLayout segments env -> LinLayout segments (Append (Lookup name segments) env)
+
+linLayoutAppend :: LinLayout segments env1 -> LinLayout segments env2 -> LinLayout segments (Append env1 env2)
+linLayoutAppend LinEnd lin = lin
+linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout segments env1')) (lin2 :: LinLayout _ env2)
+ | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2
+ = LinApp name (linLayoutAppend lin1 lin2)
+
+lineariseLayout :: Layout segments env -> LinLayout segments env
+lineariseLayout (LSeg @name :: Layout _ seg)
+ | Refl <- lemAppendNil @seg
+ = LinApp (symbolSing @name) LinEnd
+lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2
+
+linLayoutEnv :: SSegments segments -> LinLayout segments env -> SList (Const ()) env
+linLayoutEnv _ LinEnd = SNil
+linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin)
+
+autoWeak :: forall segments env1 env2. KnownSegments segments
+ => Layout segments env1 -> Layout segments env2 -> env1 :> env2
+autoWeak ly1 ly2 = sortLinLayouts knownSegments (lineariseLayout ly1) (lineariseLayout ly2)
+
+pullDown :: SSegments segments -> SSymbol name -> LinLayout segments env
+ -> r -- Name was not found in source
+ -> (forall env'. LinLayout segments env' -> env :> Append (Lookup name segments) env' -> r)
+ -> r
+pullDown segs name@SSymbol linlayout kNotFound k =
+ case linlayout of
+ LinEnd -> kNotFound
+ LinApp n'@SSymbol lin
+ | Just Refl <- sameSymbol name n' -> k lin WId
+ | otherwise ->
+ pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ env') w ->
+ k (LinApp n' lin') (WSwap (segmentLookup segs n') (segmentLookup segs name) (linLayoutEnv segs lin')
+ .> wCopies (segmentLookup segs n') w)
+
+sortLinLayouts :: forall segments env1 env2.
+ SSegments segments
+ -> LinLayout segments env1 -> LinLayout segments env2 -> env1 :> env2
+sortLinLayouts _ LinEnd LinEnd = WId
+sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2)
+ | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2)
+ | otherwise =
+ pullDown segs name2 lin1
+ (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2)
+ (\tail1' w ->
+ -- We've pulled down name2 in lin1 so that it's at the head; the
+ -- resulting modified tail is tail1'. Thus now we have (name2 : tail1')
+ -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and
+ -- wCopies the name2 on top of that.
+ wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w)
+sortLinLayouts _ LinEnd LinApp{} = error "Unequal number of segments: more in target than in source"
+sortLinLayouts _ LinApp{} LinEnd = error "Unequal number of segments: more in source than in target"
diff --git a/src/CHAD.hs b/src/CHAD.hs
index a6dd9ff..45d2d08 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -27,10 +27,12 @@ module CHAD (
) where
import Data.Bifunctor (first, second)
+import Data.Functor.Const
import Data.Kind (Type)
import GHC.TypeLits (Symbol)
import AST
+import AST.Count
import Data
import Lemmas
@@ -142,10 +144,6 @@ sreverse :: SList f ts -> SList f (Reverse ts)
sreverse SNil = SNil
sreverse (SCons t ts) = ssnoc (sreverse ts) t
-sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2)
-sappend SNil l = l
-sappend (SCons x xs) l = SCons x (sappend xs l)
-
stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts)
stapeUnfoldings SNil = SNil
stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts)
@@ -408,6 +406,63 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext
zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
+accumPromote :: forall t env sto proxy r.
+ proxy t
+ -> Descr env sto
+ -> OccEnv env
+ -> (forall stoRepl envPro.
+ Descr env stoRepl
+ -- ^ A revised environment description that switches
+ -- arrays (used in the OccEnv) that are currently on
+ -- "merge" storage, to "accum" storage.
+ -> Subenv (Select env sto "merge") (Select env stoRepl "merge")
+ -- ^ The new storage has fewer "merge"-storage entries.
+ -> SList STy envPro
+ -- ^ New entries on top of the original dual environment,
+ -- that house the accumulators for the promoted arrays in
+ -- the original environment.
+ -> (forall shbinds.
+ SList STy shbinds
+ -> (D2 t : Append shbinds (D2AcE (Select env stoRepl "accum")))
+ :> Append envPro (D2 t : Append shbinds (D2AcE (Select env sto "accum"))))
+ -- ^ A weakening that converts a computation in the
+ -- revised environment to one in the original environment
+ -- extended with some accumulators.
+ -> r)
+ -> r
+accumPromote _ DTop _ k = k DTop SETop SNil (\_ -> WId)
+accumPromote _ descr OccEnd k = k descr (subenvAll (select SMerge descr)) SNil (\_ -> WId)
+accumPromote pty (descr `DPush` (t, sto)) (occenv `OccPush` occ) k =
+ accumPromote pty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf ->
+ case (t, sto, occ) of
+ (STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero ->
+ k (storepl `DPush` (t, SAccum))
+ (SENo mergesub)
+ (STAccum arrn arrt `SCons` envpro)
+ (\(shbinds :: SList _ shbinds) ->
+ let shbindsC = slistMap (\_ -> Const ()) shbinds
+ in
+ -- wf:
+ -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WCopy wf:
+ -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WPICK: ^ THESE TWO ||
+ -- goal: | ARE EQUAL ||
+ -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ WCopy (wf shbinds)
+ .> WPick @(TAccum arrn arrt) @(D2 t : shbinds) (Const () `SCons` shbindsC)
+ (WId @(D2AcE (Select env1 stoRepl "accum"))))
+
+ (_, SAccum, _) ->
+ k (storepl `DPush` (t, SAccum))
+ mergesub
+ envpro
+ (\shbinds -> _ (wf shbinds))
+
+ _ -> _
+
+-- | @env'@ is a subset of @env@: each element of @env@ is either included in
+-- @env'@ ('SEYes') or not included in @env'@ ('SENo').
data Subenv env env' where
SETop :: Subenv '[] '[]
SEYes :: Subenv env env' -> Subenv (t : env) (t : env')
@@ -440,11 +495,15 @@ subList SNil SETop = SNil
subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub)
subList (SCons _ xs) (SENo sub) = subList xs sub
-subenvNone :: SList STy env -> Subenv env '[]
+subenvNone :: SList f env -> Subenv env '[]
subenvNone SNil = SETop
subenvNone (SCons _ env) = SENo (subenvNone env)
-subenvOnehot :: SList STy env -> Idx env t -> Subenv env '[t]
+subenvAll :: SList f env -> Subenv env env
+subenvAll SNil = SETop
+subenvAll (SCons _ env) = SEYes (subenvAll env)
+
+subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t]
subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env)
subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i)
subenvOnehot SNil i = case i of {}
@@ -525,10 +584,15 @@ weakenRets w (Rets binds list) =
let (binds', _) = weakenBindings weakenExpr w binds
in Rets binds' (slistMap (weakenRetPair (bindingsBinds binds) w) list)
-rebaseRetPair :: forall env b1 b2 env0 sto t f. SList f b1 -> SList f b2 -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t
-rebaseRetPair b1 b2 (RetPair p sub d)
+rebaseRetPair :: forall env b1 b2 env0 sto t f.
+ SList f env0 -> SList f b1 -> SList f b2
+ -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t
+rebaseRetPair env b1 b2 (RetPair p sub d)
| Refl <- lemAppendAssoc @b2 @b1 @env =
- RetPair p sub (weakenExpr (WCopy (wStack @(D2AcE (Select env0 sto "accum")) (wRaiseAbove b2 b1))) d)
+ RetPair p sub (weakenExpr (autoWeak @['("d", '[D2 t]), '("b2", b2), '("b1", b1), '("tl", D2AcE (Select env0 sto "accum"))]
+ (LSeg @"d" :++: (LSeg @"b2" :++: LSeg @"tl"))
+ (LSeg @"d" :++: ((LSeg @"b2" :++: LSeg @"b1") :++: LSeg @"tl")))
+ d)
retConcat :: forall env0 sto list. SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
retConcat SNil = Rets BTop SNil
diff --git a/src/Data.hs b/src/Data.hs
index c3381d5..a3f4c3c 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -8,6 +8,8 @@
{-# LANGUAGE TypeOperators #-}
module Data where
+import Lemmas (Append)
+
data SList f l where
SNil :: SList f '[]
@@ -19,6 +21,10 @@ slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list
slistMap _ SNil = SNil
slistMap f (SCons x list) = SCons (f x) (slistMap f list)
+sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2)
+sappend SNil l = l
+sappend (SCons x xs) l = SCons x (sappend xs l)
+
data Nat = Z | S Nat
deriving (Show, Eq, Ord)
diff --git a/src/Lemmas.hs b/src/Lemmas.hs
index cb62155..31a43ed 100644
--- a/src/Lemmas.hs
+++ b/src/Lemmas.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
module Lemmas (module Lemmas, (:~:)(Refl)) where
@@ -8,8 +9,10 @@ module Lemmas (module Lemmas, (:~:)(Refl)) where
import Data.Type.Equality
import Unsafe.Coerce (unsafeCoerce)
-import AST.Weaken
+type family Append a b where
+ Append '[] l = l
+ Append (x : xs) l = x : Append xs l
lemAppendNil :: Append a '[] :~: a
lemAppendNil = unsafeCoerce Refl