diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-10-08 21:42:05 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-10-08 21:42:05 +0200 |
commit | 947ea218c54cd12f64320f6207fc8da83f7a8c43 (patch) | |
tree | 8e0c11b69673345e66b092bd792ee3a0b4f101aa | |
parent | 63f6e894fea21813104337b1c6a2507f59337090 (diff) |
Complete occCountX
-rw-r--r-- | src/AST.hs | 3 | ||||
-rw-r--r-- | src/AST/Count.hs | 81 | ||||
-rw-r--r-- | src/AST/Types.hs | 28 | ||||
-rw-r--r-- | src/CHAD.hs | 6 | ||||
-rw-r--r-- | src/CHAD/Top.hs | 4 | ||||
-rw-r--r-- | src/Compile.hs | 2 |
6 files changed, 73 insertions, 51 deletions
@@ -82,6 +82,9 @@ data Expr x env t where -- be backpropagated to; 'a' is the inactive part. The dual field of -- ECustom does not allow a derivative to be generated for 'a', and hence -- none is propagated. + -- No accumulators are allowed inside a, b and tape. This restriction is + -- currently not used very much, so could be relaxed in the future; be sure + -- to check this requirement whenever it is necessary for soundness! ECustom :: x t -> STy a -> STy b -> STy tape -> Expr x [b, a] t -- ^ regular operation -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 5666289..bec5a9d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -561,23 +561,7 @@ occCountX initialS topexpr k = case topexpr of Just Refl -> expr Nothing -> error "unreachable" - ESum1Inner _ e - | STArr (SS n) _ <- typeOf e -> - case s of - SsNone -> - occCountX SsNone e $ \env mke -> - k env $ \env' -> - use (mke env') $ ENil ext - SsArr SsNone -> - occCountX (SsArr SsNone) e $ \env mke -> - k env $ \env' -> - elet (mke env') $ - EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) - SsArr SsFull -> - occCountX (SsArr SsFull) e $ \env mke -> - k env $ \env' -> - ESum1Inner ext (mke env') - SsFull -> occCountX (SsArr SsFull) topexpr k + ESum1Inner _ e -> handleReduction (ESum1Inner ext) e EUnit _ e -> case s of @@ -607,14 +591,8 @@ occCountX initialS topexpr k = case topexpr of EReplicate1Inner ext (mka env') (mkb env') SsFull -> occCountX (SsArr SsFull) topexpr k - {- - EMaximum1Inner _ e -> - let (e', env) = re (SsArr (ssUnarr s)) e - in (EMaximum1Inner s e', env) - EMinimum1Inner _ e -> - let (e', env) = re (SsArr (ssUnarr s)) e - in (EMinimum1Inner s e', env) - -} + EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e + EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e EConst _ t x -> k OccEnd $ \_ -> @@ -680,15 +658,26 @@ occCountX initialS topexpr k = case topexpr of k env1 $ \env' -> projectSmallerSubstruc SsFull s $ EOp ext op (mke env') - {- - ECustom _ t1 t2 t3 e1 e2 e3 a b -> - let (e1', _) = occCountX SsFull e1 - (e2', _) = occCountX SsFull e2 - (e3', _) = occCountX SsFull e3 - (a', env1) = re SsFull a -- let's be pessimistic here for safety - (b', env2) = re SsFull b - in (ECustom SsFull t1 t2 t3 e1' e2' e3' a' b', env1 <> env2) - -} + ECustom _ t1 t2 t3 e1 e2 e3 a b + | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 -> + error "Accumulators not allowed in input/output/tape of an ECustom" + | otherwise -> + case s of + SsNone -> + -- Allowed to ignore e1/e2/e3 here because no accumulators are + -- communicated, and hence no relevant effects exist + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + s' -> -- Let's be pessimistic for safety + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s' $ + ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env') ERecompute _ e -> occCountX s e $ \env1 mke -> @@ -781,11 +770,31 @@ occCountX initialS topexpr k = case topexpr of EError _ t msg -> k OccEnd $ \_ -> EError ext (applySubstruc s t) msg - - _ -> error "occCountX: TODO unimplemented" where s = simplifySubstruc (typeOf topexpr) initialS + handleReduction :: t ~ TArr n (TScal t2) + => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) + -> Expr x env (TArr (S n) (TScal t2)) + -> r + handleReduction reduce e + | STArr (SS n) _ <- typeOf e = + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr SsNone -> + occCountX (SsArr SsNone) e $ \env mke -> + k env $ \env' -> + elet (mke env') $ + EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) + SsArr SsFull -> + occCountX (SsArr SsFull) e $ \env mke -> + k env $ \env' -> + reduce (mke env') + SsFull -> occCountX (SsArr SsFull) topexpr k + deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r deleteUnused SNil (Some OccEnd) k = k SETop diff --git a/src/AST/Types.hs b/src/AST/Types.hs index 42bfb92..4ddcb50 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -171,15 +171,25 @@ type family ScalIsIntegral t where ScalIsIntegral TBool = False -- | Returns true for arrays /and/ accumulators. -hasArrays :: STy t' -> Bool -hasArrays STNil = False -hasArrays (STPair a b) = hasArrays a || hasArrays b -hasArrays (STEither a b) = hasArrays a || hasArrays b -hasArrays (STLEither a b) = hasArrays a || hasArrays b -hasArrays (STMaybe t) = hasArrays t -hasArrays STArr{} = True -hasArrays STScal{} = False -hasArrays STAccum{} = True +typeHasArrays :: STy t' -> Bool +typeHasArrays STNil = False +typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STMaybe t) = typeHasArrays t +typeHasArrays STArr{} = True +typeHasArrays STScal{} = False +typeHasArrays STAccum{} = True + +typeHasAccums :: STy t' -> Bool +typeHasAccums STNil = False +typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STMaybe t) = typeHasAccums t +typeHasAccums STArr{} = False +typeHasAccums STScal{} = False +typeHasAccums STAccum{} = True type family Tup env where Tup '[] = TNil diff --git a/src/CHAD.hs b/src/CHAD.hs index dcb10aa..08c0a2f 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -782,7 +782,7 @@ drev des accumMap sd = \case (ENil ext) ELet _ (rhs :: Expr _ _ a) body - | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge + | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 @@ -881,8 +881,8 @@ drev des accumMap sd = \case ECase _ e (a :: Expr _ _ t) b | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e - , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge - , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge + , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge + , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge , let (bindids1, bindids2) = validSplitEither (extOf e) , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 <- drevScoped des accumMap t1 storage1 bindids1 sd a diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs index 484779e..4814bdf 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Top.hs @@ -43,8 +43,8 @@ mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r accumDescr SNil k = k DTop accumDescr (t `SCons` env) k = accumDescr env $ \des -> - if hasArrays t then k (des `DPush` (t, Nothing, SAccum)) - else k (des `DPush` (t, Nothing, SMerge)) + if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum)) + else k (des `DPush` (t, Nothing, SMerge)) reassembleD2E :: Descr env sto -> D1E env :> env' diff --git a/src/Compile.hs b/src/Compile.hs index a5c4fb7..2b7cd9e 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1481,7 +1481,7 @@ copyForWriting topty var = case topty of -- nesting we'd have to check the refcounts of all the nested arrays _too_; -- let's not do that. Furthermore, no sub-arrays means that the whole thing -- is flat, and we can just memcpy if necessary. - SMTArr n t | not (hasArrays (fromSMTy t)) -> do + SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do name <- genName shszname <- genName' "shsz" emit $ SVarDeclUninit toptyname name |