diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-06 17:07:22 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-04-06 17:07:22 +0200 |
commit | 0a9e6dfc1accf9dc0254f0c720f633dab6e71f42 (patch) | |
tree | 754eaeecf01e554d7ad904c27a9b665879441ca0 | |
parent | b6c1d3a9d0651aa25ea5f03d514a214a3347f7a4 (diff) |
-rw-r--r-- | src/Analysis/Identity.hs | 10 | ||||
-rw-r--r-- | src/CHAD.hs | 174 | ||||
-rw-r--r-- | src/CHAD/EnvDescr.hs | 33 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 12 | ||||
-rw-r--r-- | src/Compile.hs | 4 | ||||
-rw-r--r-- | src/Compile/Exec.hs | 3 | ||||
-rw-r--r-- | src/Data/VarMap.hs | 30 | ||||
-rw-r--r-- | src/Example.hs | 24 |
8 files changed, 168 insertions, 122 deletions
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 54f7cd2..186ab71 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -2,10 +2,12 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} module Analysis.Identity ( - ValId(..), identityAnalysis, identityAnalysis', + ValId(..), + validSplitEither, ) where import Data.Foldable (toList) @@ -31,6 +33,7 @@ data ValId t where VIArr :: Int -> Vec n Int -> ValId (TArr n t) VIScal :: Int -> ValId (TScal t) VIAccum :: Int -> ValId (TAccum t) +deriving instance Show (ValId t) instance PrettyX ValId where prettyX = \case @@ -46,6 +49,11 @@ instance PrettyX ValId where VIScal i -> show i VIAccum i -> 'C' : show i +validSplitEither :: ValId (TEither a b) -> (Maybe (ValId a), Maybe (ValId b)) +validSplitEither (VIEither (Left v)) = (Just v, Nothing) +validSplitEither (VIEither (Right v)) = (Nothing, Just v) +validSplitEither (VIEither' v1 v2) = (Just v1, Just v2) + -- | Symbolic partial evaluation. identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t identityAnalysis env term = runIdGen 0 $ do diff --git a/src/CHAD.hs b/src/CHAD.hs index 6a4d5f5..8db0410 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -36,7 +36,7 @@ import Data.Type.Bool (If) import Data.Type.Equality (type (==)) import GHC.Stack (HasCallStack) -import Analysis.Identity (ValId(..)) +import Analysis.Identity (ValId(..), validSplitEither) import AST import AST.Bindings import AST.Count @@ -294,18 +294,18 @@ data Idx2 env sto t | Idx2Di (Idx (Select env sto "discr") t) conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t -conv2Idx (DPush _ (_, SAccum)) IZ = Idx2Ac IZ -conv2Idx (DPush _ (_, SMerge)) IZ = Idx2Me IZ -conv2Idx (DPush _ (_, SDiscr)) IZ = Idx2Di IZ -conv2Idx (DPush des (_, SAccum)) (IS i) = +conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ +conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ +conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ +conv2Idx (DPush des (_, _, SAccum)) (IS i) = case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) Idx2Me j -> Idx2Me j Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, SMerge)) (IS i) = +conv2Idx (DPush des (_, _, SMerge)) (IS i) = case conv2Idx des i of Idx2Ac j -> Idx2Ac j Idx2Me j -> Idx2Me (IS j) Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, SDiscr)) (IS i) = +conv2Idx (DPush des (_, _, SDiscr)) (IS i) = case conv2Idx des i of Idx2Ac j -> Idx2Ac j Idx2Me j -> Idx2Me j Idx2Di j -> Idx2Di (IS j) @@ -376,6 +376,10 @@ assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" --------------------------------- ACCUMULATORS --------------------------------- +fromArrayValId :: Maybe (ValId t) -> Maybe Int +fromArrayValId (Just (VIArr i _)) = Just i +fromArrayValId _ = Nothing + accumPromote :: forall dt env sto proxy r. proxy dt -> Descr env sto @@ -384,8 +388,7 @@ accumPromote :: forall dt env sto proxy r. => Descr env stoRepl -- ^ A revised environment description that switches -- arrays (used in the OccEnv) that are currently on - -- "merge" storage, to "accum" storage. Any other "merge" - -- entries are deleted. + -- "merge" storage, to "accum" storage. -> SList STy envPro -- ^ New entries on top of the original dual environment, -- that house the accumulators for the promoted arrays in @@ -393,6 +396,12 @@ accumPromote :: forall dt env sto proxy r. -> Subenv (Select env sto "merge") envPro -- ^ The promoted entries were merge entries in the -- original environment. + -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) + -- ^ All entries that were accumulators are still + -- accumulators. + -> VarMap Int (D2AcE (Select env stoRepl "accum")) + -- ^ Accumulator map for _only_ the the newly allocated + -- accumulators. -> (forall shbinds. SList STy shbinds -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) @@ -402,57 +411,70 @@ accumPromote :: forall dt env sto proxy r. -- extended with some accumulators. -> r) -> r -accumPromote _ DTop k = k DTop SNil SETop (\_ -> WId) -accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub wf -> - case sto of - -- Accumulators are left as-is - SAccum -> - k (storepl `DPush` (t, SAccum)) - envpro - prosub - (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr))) - (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) - (#pro :++: #d :++: #shb :++: #acc :++: #tl) - .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl))) - (#d :++: #shb :++: #acc :++: #tl) - (#acc :++: (#d :++: #shb :++: #tl))) - - SMerge -> case t of - -- Discrete values are left as-is - _ | isDiscrete t -> - k (storepl `DPush` (t, SDiscr)) - envpro - (SENo prosub) - wf - - -- Values with "merge" storage are promoted to an accumulator in envPro - _ -> - k (storepl `DPush` (t, SAccum)) - (t `SCons` envpro) - (SEYes prosub) - (\(shbinds :: SList _ shbinds) -> - let shbindsC = slistMap (\_ -> Const ()) shbinds - in - -- wf: - -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WCopy wf: - -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WPICK: ^ THESE TWO || - -- goal: | ARE EQUAL || - -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - WCopy (wf shbinds) - .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC) - (WId @(D2AcE (Select env1 stoRepl "accum")))) - - -- Discrete values are left as-is, nothing to do - SDiscr -> - k (storepl `DPush` (t, SDiscr)) +accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) +accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of + -- Accumulators are left as-is + SAccum -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + envpro + prosub + (SEYes accrevsub) + (VarMap.sink1 accumMap) + (\shbinds -> + autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr))) + (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) + (#pro :++: #d :++: #shb :++: #acc :++: #tl) + .> WCopy (wf shbinds) + .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl))) + (#d :++: #shb :++: #acc :++: #tl) + (#acc :++: (#d :++: #shb :++: #tl))) + + SMerge -> case t of + -- Discrete values are left as-is + _ | isDiscrete t -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> + k (storepl `DPush` (t, vid, SDiscr)) envpro - prosub + (SENo prosub) + accrevsub + accumMap' wf + + -- Values with "merge" storage are promoted to an accumulator in envPro + _ -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + (t `SCons` envpro) + (SEYes prosub) + (SENo accrevsub) + (let accumMap' = VarMap.sink1 accumMap + in case fromArrayValId vid of + Just i -> VarMap.insert i (STAccum t) IZ accumMap' + Nothing -> accumMap') + (\(shbinds :: SList _ shbinds) -> + let shbindsC = slistMap (\_ -> Const ()) shbinds + in + -- wf: + -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WCopy wf: + -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WPICK: ^ THESE TWO || + -- goal: | ARE EQUAL || + -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + WCopy (wf shbinds) + .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + (WId @(D2AcE (Select env1 stoRepl "accum")))) + + -- Discrete values are left as-is, nothing to do + SDiscr -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SDiscr)) + envpro + prosub + accrevsub + accumMap + wf where isDiscrete :: STy t' -> Bool isDiscrete = \case @@ -561,7 +583,7 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = drev :: forall env sto t. (?config :: CHADConfig) - => Descr env sto -> VarMap Int env + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) -> Expr ValId env t -> Ret env sto t drev des accumMap = \case EVar _ t i -> @@ -590,7 +612,7 @@ drev des accumMap = \case ELet _ (rhs :: Expr _ _ a) body | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage body + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> @@ -687,8 +709,9 @@ drev des accumMap = \case , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 a - , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 b + , let (bindids1, bindids2) = validSplitEither (extOf e) + , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a + , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) @@ -819,8 +842,9 @@ drev des accumMap = \case deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> - case drev (prodes `DPush` (shty, SDiscr)) (VarMap.sink1 (VarMap.subMap usedSub accumMap)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> + accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in let collectexpr = bindingsCollect e0 subtapeE in @@ -1055,17 +1079,22 @@ deriving instance Show (RetScoped env0 sto a s t) drevScoped :: forall a s env sto t. (?config :: CHADConfig) - => Descr env sto -> VarMap Int env -> STy a -> Storage s + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> STy a -> Storage s -> Maybe (ValId a) -> Expr ValId (a : env) t -> RetScoped env sto a s t -drevScoped des accumMap argty argsto expr - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argsto)) (VarMap.sink1 accumMap) expr - = case argsto of - SMerge -> +drevScoped des accumMap argty argsto argids expr = case argsto of + SMerge + | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> case sub of SEYes sub' -> RetScoped e0 subtape e1 sub' e2 SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty)) - SAccum -> + + SAccum + | let accumMap' = case argids of + Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap) + _ -> VarMap.sink1 accumMap + , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr -> RetScoped e0 subtape e1 sub $ EWith ext argty (EZero ext argty) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) @@ -1075,4 +1104,7 @@ drevScoped des accumMap argty argsto expr (#d :++: #body :++: #ac :++: #tl) (#ac :++: #d :++: #body :++: #tl)) e2 - SDiscr -> RetScoped e0 subtape e1 sub e2 + + SDiscr + | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> + RetScoped e0 subtape e1 sub e2 diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs index fcd91f7..de615a1 100644 --- a/src/CHAD/EnvDescr.hs +++ b/src/CHAD/EnvDescr.hs @@ -11,6 +11,7 @@ module CHAD.EnvDescr where import Data.Kind (Type) import GHC.TypeLits (Symbol) +import Analysis.Identity (ValId(..)) import AST.Env import AST.Types import CHAD.Types @@ -27,12 +28,12 @@ deriving instance Show (Storage s) -- | Environment description data Descr env sto where DTop :: Descr '[] '[] - DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) + DPush :: Descr env sto -> (STy t, Maybe (ValId t), Storage s) -> Descr (t : env) (s : sto) deriving instance Show (Descr env sto) descrList :: Descr env sto -> SList STy env descrList DTop = SNil -descrList (des `DPush` (t, _)) = t `SCons` descrList des +descrList (des `DPush` (t, _, _)) = t `SCons` descrList des -- | This could have more precise typing on the output storage. subDescr :: Descr env sto -> Subenv env env' @@ -43,13 +44,13 @@ subDescr :: Descr env sto -> Subenv env env' -> r) -> r subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, sto)) (SEYes sub) k = +subDescr (des `DPush` (t, vid, sto)) (SEYes sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of - SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) - SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) - SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) -subDescr (des `DPush` (_, sto)) (SENo sub) k = + SMerge -> k (des' `DPush` (t, vid, sto)) (SEYes submerge) subaccum (SEYes subd1e) + SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYes subaccum) (SEYes subd1e) + SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYes subd1e) +subDescr (des `DPush` (_, _, sto)) (SENo sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) @@ -64,12 +65,12 @@ type family Select env sto s where select :: Storage s -> Descr env sto -> SList STy (Select env sto s) select _ DTop = SNil -select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) -select s@SMerge (DPush des (_, SAccum)) = select s des -select s@SDiscr (DPush des (_, SAccum)) = select s des -select s@SAccum (DPush des (_, SMerge)) = select s des -select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) -select s@SDiscr (DPush des (_, SMerge)) = select s des -select s@SAccum (DPush des (_, SDiscr)) = select s des -select s@SMerge (DPush des (_, SDiscr)) = select s des -select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) +select s@SAccum (DPush des (t, _, SAccum)) = SCons t (select s des) +select s@SMerge (DPush des (_, _, SAccum)) = select s des +select s@SDiscr (DPush des (_, _, SAccum)) = select s des +select s@SAccum (DPush des (_, _, SMerge)) = select s des +select s@SMerge (DPush des (t, _, SMerge)) = SCons t (select s des) +select s@SDiscr (DPush des (_, _, SMerge)) = select s des +select s@SAccum (DPush des (_, _, SDiscr)) = select s des +select s@SMerge (DPush des (_, _, SDiscr)) = select s des +select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des) diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index ea7449d..2c01178 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -28,7 +28,7 @@ type family MergeEnv env where mergeDescr :: SList STy env -> Descr env (MergeEnv env) mergeDescr SNil = DTop -mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, SMerge) +mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, Nothing, SMerge) mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[] mergeEnvNoAccum SNil = Refl @@ -41,8 +41,8 @@ mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r accumDescr SNil k = k DTop accumDescr (t `SCons` env) k = accumDescr env $ \des -> - if hasArrays t then k (des `DPush` (t, SAccum)) - else k (des `DPush` (t, SMerge)) + if hasArrays t then k (des `DPush` (t, Nothing, SAccum)) + else k (des `DPush` (t, Nothing, SMerge)) d1Identity :: STy t -> D1 t :~: t d1Identity = \case @@ -62,17 +62,17 @@ reassembleD2E :: Descr env sto -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) -> Ex env' (Tup (D2E env)) reassembleD2E DTop _ = ENil ext -reassembleD2E (des `DPush` (_, SAccum)) e = +reassembleD2E (des `DPush` (_, _, SAccum)) e = ELet ext e $ EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ))) (ESnd ext (EVar ext (typeOf e) IZ)))) (ESnd ext (EFst ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (_, SMerge)) e = +reassembleD2E (des `DPush` (_, _, SMerge)) e = ELet ext e $ EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (t, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t) +reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t) chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) chad config env (term :: Ex env t) diff --git a/src/Compile.hs b/src/Compile.hs index 7bbb043..2a184f7 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -7,7 +7,7 @@ {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeApplications #-} -module Compile (compile, debugCSource, debugRefc, emitChecks) where +module Compile (compile) where import Control.Applicative (empty) import Control.Monad (forM_, when, replicateM) @@ -71,7 +71,7 @@ compile :: SList STy env -> Ex env t compile = \env expr -> do let source = compileToString env expr when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" - when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ source ++ "\x1B[0m>>>" + when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" lib <- buildKernel source ["kernel"] let arg_metrics = reverse (unSList metricsSTy env) diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs index 5f90ea2..d708fc0 100644 --- a/src/Compile/Exec.hs +++ b/src/Compile/Exec.hs @@ -4,6 +4,9 @@ module Compile.Exec ( KernelLib, buildKernel, callKernelFun, + + -- * misc + lineNumbers, ) where import Control.Monad (when) diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs index 16c2d27..9c10421 100644 --- a/src/Data/VarMap.hs +++ b/src/Data/VarMap.hs @@ -11,9 +11,11 @@ module Data.VarMap ( delete, TypedIdx(..), lookup, + disjointUnion, sink1, unsink1, subMap, + superMap, ) where import Prelude hiding (lookup) @@ -57,6 +59,11 @@ lookup k (VarMap off _ mp) = do idx <- unsafeInt2idx (i + off) return (Some (TypedIdx ty idx)) +disjointUnion :: Ord k => VarMap k env -> VarMap k env -> VarMap k env +disjointUnion (VarMap off1 cl1 m1) (VarMap off2 cl2 m2) | off1 == off2 = + VarMap off1 (min cl1 cl2) (Map.unionWith (error "VarMap.disjointUnion: overlapping keys") m1 m2) +disjointUnion vm1 vm2 = disjointUnion (cleanup vm1) (cleanup vm2) + sink1 :: VarMap k env -> VarMap k (t : env) sink1 (VarMap off interval mp) = VarMap (off + 1) interval mp @@ -78,13 +85,32 @@ subMap subenv = | otherwise = Nothing in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp) +superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env +superMap subenv = + let loop :: Subenv env env' -> Int -> [Int] + loop SETop _ = [] + loop (SEYes sub) i = i : loop sub (i+1) + loop (SENo sub) i = loop sub (i+1) + + newIndices = VS.fromList $ loop subenv 0 + modify off (k, (ty, i)) + | i + off < 0 = Nothing + | i + off >= VS.length newIndices = error "VarMap.superMap: found negative indices in map" + | otherwise = let j = newIndices VS.! (i + off) + in if j == -1 then Nothing else Just (k, (ty, j)) + + in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp) + maybeCleanup :: VarMap k env -> VarMap k env -maybeCleanup (VarMap off interval mp) +maybeCleanup vm@(VarMap _ interval mp) | let sz = Map.size mp , sz > 0, 2 * interval >= 3 * sz - = VarMap off 0 (Map.filter (\(_, i) -> i + off >= 0) mp) + = cleanup vm maybeCleanup vm = vm +cleanup :: VarMap k env -> VarMap k env +cleanup (VarMap off _ mp) = VarMap 0 0 (Map.mapMaybe (\(t, i) -> if i + off >= 0 then Just (t, i + off) else Nothing) mp) + unsafeInt2idx :: Int -> Maybe (Idx env t) unsafeInt2idx = \n -> if n < 0 then Nothing else Just (go n) where diff --git a/src/Example.hs b/src/Example.hs index 2c710a1..4fa8d5a 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -30,11 +30,6 @@ bin op a b = EOp ext op (EPair ext a b) senv1 :: SList STy [TScal TF32, TScal TF32] senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil -descr1 :: Storage a -> Storage b - -> Descr [TScal TF32, TScal TF32] [b, a] -descr1 a b = DTop `DPush` (t, a) `DPush` (t, b) - where t = STScal STF32 - -- x y |- x * y + x -- -- let x3 = (x1, x2) @@ -82,25 +77,12 @@ ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32) ex4 = fromNamed $ lambda #x $ lambda #y $ body $ if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x) -senv5 :: SList STy [TScal TF32, TEither (TScal TF32) (TScal TF32)] -senv5 = knownEnv - -descr5 :: Storage a -> Storage b - -> Descr [TScal TF32, TEither (TScal TF32) (TScal TF32)] [b, a] -descr5 a b = DTop `DPush` (knownTy, a) `DPush` (knownTy, b) - -- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)} ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) ex5 = fromNamed $ lambda #x $ lambda #y $ body $ case_ #x (#a :-> #a * #y) (#b :-> #b * (#y + 1)) -senv6 :: SList STy [TScal TI64, TScal TF32] -senv6 = knownEnv - -descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"] -descr6 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge) - -- x:R n:I |- let a = unit x -- b = build1 n (\i. let c = idx0 a in c * c) -- in idx0 (b ! 3) @@ -110,12 +92,6 @@ ex6 = fromNamed $ lambda #x $ lambda #n $ body $ let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $ #b ! pair nil 3 -senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] -senv7 = knownEnv - -descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"] -descr7 = DTop `DPush` (knownTy, SMerge) `DPush` (knownTy, SMerge) - -- A "neural network" except it's just scalars, not matrices. -- ps:((((), (R,R)), (R,R)), (R,R)) x:R -- |- let p1 = snd ps |