aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/AST.hs3
-rw-r--r--src/AST/Count.hs81
-rw-r--r--src/AST/Types.hs28
-rw-r--r--src/CHAD.hs6
-rw-r--r--src/CHAD/Top.hs4
-rw-r--r--src/Compile.hs2
6 files changed, 73 insertions, 51 deletions
diff --git a/src/AST.hs b/src/AST.hs
index b8bee1b..a10f1ae 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -82,6 +82,9 @@ data Expr x env t where
-- be backpropagated to; 'a' is the inactive part. The dual field of
-- ECustom does not allow a derivative to be generated for 'a', and hence
-- none is propagated.
+ -- No accumulators are allowed inside a, b and tape. This restriction is
+ -- currently not used very much, so could be relaxed in the future; be sure
+ -- to check this requirement whenever it is necessary for soundness!
ECustom :: x t -> STy a -> STy b -> STy tape
-> Expr x [b, a] t -- ^ regular operation
-> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 5666289..bec5a9d 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -561,23 +561,7 @@ occCountX initialS topexpr k = case topexpr of
Just Refl -> expr
Nothing -> error "unreachable"
- ESum1Inner _ e
- | STArr (SS n) _ <- typeOf e ->
- case s of
- SsNone ->
- occCountX SsNone e $ \env mke ->
- k env $ \env' ->
- use (mke env') $ ENil ext
- SsArr SsNone ->
- occCountX (SsArr SsNone) e $ \env mke ->
- k env $ \env' ->
- elet (mke env') $
- EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
- SsArr SsFull ->
- occCountX (SsArr SsFull) e $ \env mke ->
- k env $ \env' ->
- ESum1Inner ext (mke env')
- SsFull -> occCountX (SsArr SsFull) topexpr k
+ ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
EUnit _ e ->
case s of
@@ -607,14 +591,8 @@ occCountX initialS topexpr k = case topexpr of
EReplicate1Inner ext (mka env') (mkb env')
SsFull -> occCountX (SsArr SsFull) topexpr k
- {-
- EMaximum1Inner _ e ->
- let (e', env) = re (SsArr (ssUnarr s)) e
- in (EMaximum1Inner s e', env)
- EMinimum1Inner _ e ->
- let (e', env) = re (SsArr (ssUnarr s)) e
- in (EMinimum1Inner s e', env)
- -}
+ EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
+ EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
EConst _ t x ->
k OccEnd $ \_ ->
@@ -680,15 +658,26 @@ occCountX initialS topexpr k = case topexpr of
k env1 $ \env' ->
projectSmallerSubstruc SsFull s $ EOp ext op (mke env')
- {-
- ECustom _ t1 t2 t3 e1 e2 e3 a b ->
- let (e1', _) = occCountX SsFull e1
- (e2', _) = occCountX SsFull e2
- (e3', _) = occCountX SsFull e3
- (a', env1) = re SsFull a -- let's be pessimistic here for safety
- (b', env2) = re SsFull b
- in (ECustom SsFull t1 t2 t3 e1' e2' e3' a' b', env1 <> env2)
- -}
+ ECustom _ t1 t2 t3 e1 e2 e3 a b
+ | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 ->
+ error "Accumulators not allowed in input/output/tape of an ECustom"
+ | otherwise ->
+ case s of
+ SsNone ->
+ -- Allowed to ignore e1/e2/e3 here because no accumulators are
+ -- communicated, and hence no relevant effects exist
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ s' -> -- Let's be pessimistic for safety
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s' $
+ ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env')
ERecompute _ e ->
occCountX s e $ \env1 mke ->
@@ -781,11 +770,31 @@ occCountX initialS topexpr k = case topexpr of
EError _ t msg ->
k OccEnd $ \_ -> EError ext (applySubstruc s t) msg
-
- _ -> error "occCountX: TODO unimplemented"
where
s = simplifySubstruc (typeOf topexpr) initialS
+ handleReduction :: t ~ TArr n (TScal t2)
+ => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2)))
+ -> Expr x env (TArr (S n) (TScal t2))
+ -> r
+ handleReduction reduce e
+ | STArr (SS n) _ <- typeOf e =
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env mke ->
+ k env $ \env' ->
+ use (mke env') $ ENil ext
+ SsArr SsNone ->
+ occCountX (SsArr SsNone) e $ \env mke ->
+ k env $ \env' ->
+ elet (mke env') $
+ EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
+ SsArr SsFull ->
+ occCountX (SsArr SsFull) e $ \env mke ->
+ k env $ \env' ->
+ reduce (mke env')
+ SsFull -> occCountX (SsArr SsFull) topexpr k
+
deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r
deleteUnused SNil (Some OccEnd) k = k SETop
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index 42bfb92..4ddcb50 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -171,15 +171,25 @@ type family ScalIsIntegral t where
ScalIsIntegral TBool = False
-- | Returns true for arrays /and/ accumulators.
-hasArrays :: STy t' -> Bool
-hasArrays STNil = False
-hasArrays (STPair a b) = hasArrays a || hasArrays b
-hasArrays (STEither a b) = hasArrays a || hasArrays b
-hasArrays (STLEither a b) = hasArrays a || hasArrays b
-hasArrays (STMaybe t) = hasArrays t
-hasArrays STArr{} = True
-hasArrays STScal{} = False
-hasArrays STAccum{} = True
+typeHasArrays :: STy t' -> Bool
+typeHasArrays STNil = False
+typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STMaybe t) = typeHasArrays t
+typeHasArrays STArr{} = True
+typeHasArrays STScal{} = False
+typeHasArrays STAccum{} = True
+
+typeHasAccums :: STy t' -> Bool
+typeHasAccums STNil = False
+typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STMaybe t) = typeHasAccums t
+typeHasAccums STArr{} = False
+typeHasAccums STScal{} = False
+typeHasAccums STAccum{} = True
type family Tup env where
Tup '[] = TNil
diff --git a/src/CHAD.hs b/src/CHAD.hs
index dcb10aa..08c0a2f 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -782,7 +782,7 @@ drev des accumMap sd = \case
(ENil ext)
ELet _ (rhs :: Expr _ _ a) body
- | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
+ | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge
, RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body
, Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs
, let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
@@ -881,8 +881,8 @@ drev des accumMap sd = \case
ECase _ e (a :: Expr _ _ t) b
| STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e
- , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
- , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
+ , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge
+ , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge
, let (bindids1, bindids2) = validSplitEither (extOf e)
, RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2
<- drevScoped des accumMap t1 storage1 bindids1 sd a
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 484779e..4814bdf 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -43,8 +43,8 @@ mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl
accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r
accumDescr SNil k = k DTop
accumDescr (t `SCons` env) k = accumDescr env $ \des ->
- if hasArrays t then k (des `DPush` (t, Nothing, SAccum))
- else k (des `DPush` (t, Nothing, SMerge))
+ if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum))
+ else k (des `DPush` (t, Nothing, SMerge))
reassembleD2E :: Descr env sto
-> D1E env :> env'
diff --git a/src/Compile.hs b/src/Compile.hs
index a5c4fb7..2b7cd9e 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1481,7 +1481,7 @@ copyForWriting topty var = case topty of
-- nesting we'd have to check the refcounts of all the nested arrays _too_;
-- let's not do that. Furthermore, no sub-arrays means that the whole thing
-- is flat, and we can just memcpy if necessary.
- SMTArr n t | not (hasArrays (fromSMTy t)) -> do
+ SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do
name <- genName
shszname <- genName' "shsz"
emit $ SVarDeclUninit toptyname name