diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 12:39:08 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-10 12:39:08 +0100 |
commit | 013e01e28aba090c065ed584671a65aa339ea51b (patch) | |
tree | 1595a8363fc181a13d41224e206d051d4e6a906b /src/CHAD.hs | |
parent | 9c3f3c4a5f1258c99aefc95944af79dd6da2586c (diff) |
Test GMM; it fails
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 101 |
1 files changed, 79 insertions, 22 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 59d61a7..8b9f17a 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -28,7 +28,6 @@ module CHAD ( Select, ) where -import Data.Bifunctor (first, second) import Data.Functor.Const import Data.Kind (Type) import GHC.Stack (HasCallStack) @@ -198,6 +197,7 @@ type Storage :: Symbol -> Type data Storage s where SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator SMerge :: Storage "merge" -- ^ just return and merge + SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions deriving instance Show (Storage s) -- | Environment description @@ -224,11 +224,13 @@ subDescr (des `DPush` (t, sto)) (SEYes sub) k = case sto of SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) + SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) subDescr (des `DPush` (_, sto)) (SENo sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) + SDiscr -> k des' submerge subaccum (SENo subd1e) -- | Select only the types from the environment that have the specified storage type family Select env sto s where @@ -240,8 +242,13 @@ select :: Storage s -> Descr env sto -> SList STy (Select env sto s) select _ DTop = SNil select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) select s@SMerge (DPush des (_, SAccum)) = select s des +select s@SDiscr (DPush des (_, SAccum)) = select s des select s@SAccum (DPush des (_, SMerge)) = select s des select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +select s@SDiscr (DPush des (_, SMerge)) = select s des +select s@SAccum (DPush des (_, SDiscr)) = select s des +select s@SMerge (DPush des (_, SDiscr)) = select s des +select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) ---------------------------------- DERIVATIVES --------------------------------- @@ -338,13 +345,27 @@ conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) -conv2Idx :: Descr env sto -> Idx env t - -> Either (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) - (Idx (Select env sto "merge") t) -conv2Idx (DPush _ (_, SAccum)) IZ = Left IZ -conv2Idx (DPush _ (_, SMerge)) IZ = Right IZ -conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i) -conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i) +data Idx2 env sto t + = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) + | Idx2Me (Idx (Select env sto "merge") t) + | Idx2Di (Idx (Select env sto "discr") t) + +conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t +conv2Idx (DPush _ (_, SAccum)) IZ = Idx2Ac IZ +conv2Idx (DPush _ (_, SMerge)) IZ = Idx2Me IZ +conv2Idx (DPush _ (_, SDiscr)) IZ = Idx2Di IZ +conv2Idx (DPush des (_, SAccum)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, SMerge)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me (IS j) + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, SDiscr)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di (IS j) conv2Idx DTop i = case i of {} @@ -536,6 +557,13 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = (#acc :++: (#d :++: #shb :++: #tl))) SMerge -> case t of + -- Discrete values are left as-is + _ | isDiscrete t -> + k (storepl `DPush` (t, SDiscr)) + envpro + (SENo prosub) + wf + -- Arrays with "merge" storage are promoted to an accumulator in envPro STArr (arrn :: SNat arrn) (arrt :: STy arrt) -> k (storepl `DPush` (t, SAccum)) @@ -555,18 +583,40 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = .> WPick @(TAccum (D2 (TArr arrn arrt))) @(D2 dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) - -- "merge" values must be an array, so reject everything else. (TODO: generalise this) + -- "merge" values must be an array or fully discrete, so reject everything else. (TODO: generalise this) _ -> - error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t - -- where - -- containsTArr :: STy t' -> Bool - -- containsTArr = \case - -- STNil -> False - -- STPair a b -> containsTArr a || containsTArr b - -- STEither a b -> containsTArr a || containsTArr b - -- STArr{} -> True - -- STScal{} -> False - -- STAccum{} -> error "An accumulator in merge storage?" + error $ "Closure variable of 'build'-like thing contains a non-array non-discrete SMerge value: " ++ show t + + -- Discrete values are left as-is, nothing to do + SDiscr -> + k (storepl `DPush` (t, SDiscr)) + envpro + prosub + wf + where + isDiscrete :: STy t' -> Bool + isDiscrete = \case + STNil -> True + STPair a b -> isDiscrete a && isDiscrete b + STEither a b -> isDiscrete a && isDiscrete b + STMaybe a -> isDiscrete a + STArr _ a -> isDiscrete a + STScal st -> case st of + STI32 -> True + STI64 -> True + STF32 -> False + STF64 -> False + STBool -> True + STAccum{} -> False + + -- containsTArr :: STy t' -> Bool + -- containsTArr = \case + -- STNil -> False + -- STPair a b -> containsTArr a || containsTArr b + -- STEither a b -> containsTArr a || containsTArr b + -- STArr{} -> True + -- STScal{} -> False + -- STAccum{} -> error "An accumulator in merge storage?" makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators SNil e = e @@ -682,20 +732,27 @@ drev :: forall env sto t. drev des = \case EVar _ t i -> case conv2Idx des i of - Left accI -> + Idx2Ac accI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI))) - Right tupI -> + Idx2Me tupI -> Ret BTop SETop (EVar ext (d1 t) (conv1Idx i)) (subenvOnehot (select SMerge des) tupI) (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) + Idx2Di _ -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (select SMerge des)) + (ENil ext) + ELet _ rhs body | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs @@ -929,7 +986,7 @@ drev des = \case let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> - case drev (prodes `DPush` (shty, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> + case drev (prodes `DPush` (shty, SDiscr)) e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in let collectexpr = bindingsCollect e0 subtapeE in |