aboutsummaryrefslogtreecommitdiff
path: root/src/Simplify.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r--src/Simplify.hs545
1 files changed, 422 insertions, 123 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
index ac1bb8b..74b6601 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -1,8 +1,12 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
+{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@@ -10,24 +14,31 @@
{-# LANGUAGE TypeOperators #-}
module Simplify (
simplifyN, simplifyFix,
- SimplifyConfig(..), simplifyWith, simplifyFixWith,
+ SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,
) where
+import Control.Monad (ap)
+import Data.Bifunctor (first)
import Data.Function (fix)
import Data.Monoid (Any(..))
-import Data.Type.Equality (testEquality)
+
+import Debug.Trace
import AST
import AST.Count
-import CHAD.Types
+import AST.Pretty
+import AST.Sparse.Types
+import AST.UnMonoid (acPrjCompose)
import Data
+import Simplify.TH
--- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some.
data SimplifyConfig = SimplifyConfig
+ { scLogging :: Bool
+ }
defaultSimplifyConfig :: SimplifyConfig
-defaultSimplifyConfig = SimplifyConfig
+defaultSimplifyConfig = SimplifyConfig False
simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
simplifyN 0 = id
@@ -37,13 +48,13 @@ simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplify =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = defaultSimplifyConfig
- in snd . simplify'
+ in snd . runSM . simplify'
simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
simplifyWith config =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = config
- in snd . simplify'
+ in snd . runSM . simplify'
simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
simplifyFix = simplifyFixWith defaultSimplifyConfig
@@ -53,22 +64,74 @@ simplifyFixWith config =
let ?accumInScope = checkAccumInScope @env knownEnv
?config = config
in fix $ \loop e ->
- let (Any act, e') = simplify' e
+ let (act, e') = runSM (simplify' e)
in if act then loop e' else e'
-simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t)
-simplify' = \case
+-- | simplify monad
+newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a))
+ deriving (Functor)
+
+instance Applicative (SM tenv tt env t) where
+ pure x = SM (\_ -> (Any False, x))
+ (<*>) = ap
+
+instance Monad (SM tenv tt env t) where
+ SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx
+
+runSM :: SM env t env t a -> (Bool, a)
+runSM (SM f) = first getAny (f id)
+
+smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt)
+smReconstruct core = SM (\ctx -> (Any False, ctx core))
+
+class Monad m => ActedMonad m where
+ tellActed :: m ()
+ hideActed :: m a -> m a
+ liftActed :: (Any, a) -> m a
+
+instance ActedMonad ((,) Any) where
+ tellActed = (Any True, ())
+ hideActed (_, x) = (Any False, x)
+ liftActed = id
+
+instance ActedMonad (SM tenv tt env t) where
+ tellActed = SM (\_ -> tellActed)
+ hideActed (SM f) = SM (\ctx -> hideActed (f ctx))
+ liftActed pair = SM (\_ -> pair)
+
+-- more convenient in practice
+acted :: ActedMonad m => m a -> m a
+acted m = tellActed >> m
+
+within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a
+within subctx (SM f) = SM $ \ctx -> f (ctx . subctx)
+
+simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
+simplify' expr
+ | scLogging ?config = do
+ res <- simplify'Rec expr
+ full <- smReconstruct res
+ let printed = ppExpr knownEnv full
+ replace a bs = concatMap (\x -> if x == a then bs else [x])
+ str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed
+ | otherwise = "--- simplify step: " ++ printed
+ traceM str
+ return res
+ | otherwise = simplify'Rec expr
+
+simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
+simplify'Rec = \case
-- inlining
ELet _ rhs body
| cheapExpr rhs
- -> acted $ simplify' (subst1 rhs body)
+ -> acted $ simplify' (substInline rhs body)
| Occ lexOcc runOcc <- occCount IZ body
, ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply
|| (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not
- -> acted $ simplify' (subst1 rhs body)
+ -> acted $ simplify' (substInline rhs body)
- -- let splitting
+ -- let splitting / let peeling
ELet _ (EPair _ a b) body ->
acted $ simplify' $
ELet ext a $
@@ -76,13 +139,20 @@ simplify' = \case
subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
IS i -> EVar ext t (IS (IS i)))
body
+ ELet _ (EJust _ a) body ->
+ acted $ simplify' $ ELet ext a $ subst0 (EJust ext (EVar ext (typeOf a) IZ)) body
+ ELet _ (EInl _ t2 a) body ->
+ acted $ simplify' $ ELet ext a $ subst0 (EInl ext t2 (EVar ext (typeOf a) IZ)) body
+ ELet _ (EInr _ t1 a) body ->
+ acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body
-- let rotation
- ELet _ (ELet _ rhs a) b ->
+ ELet _ (ELet _ rhs a) b -> do
+ b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b
acted $ simplify' $
ELet ext rhs $
ELet ext a $
- weakenExpr (WCopy WSink) (snd (simplify' b))
+ weakenExpr (WCopy WSink) b'
-- beta rules for products
EFst _ (EPair _ e e')
@@ -100,12 +170,20 @@ simplify' = \case
EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1
EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1
- -- let floating to facilitate beta reduction
+ -- let floating
EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body))
ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body))
ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))
EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))
EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
+ EAccum _ t p e1 sp (ELet _ rhs body) acc ->
+ acted $ simplify' $
+ ELet ext rhs $
+ EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc)
+
+ -- let () = e in () ~> e
+ ELet _ e1 (ENil _) | STNil <- typeOf e1 ->
+ acted $ simplify' e1
-- projection down-commuting
EFst _ (ECase _ e1 e2 e3) ->
@@ -114,89 +192,150 @@ simplify' = \case
ESnd _ (ECase _ e1 e2 e3) ->
acted $ simplify' $
ECase ext e1 (ESnd ext e2) (ESnd ext e3)
+ EFst _ (EMaybe _ e1 e2 e3) ->
+ acted $ simplify' $
+ EMaybe ext (EFst ext e1) (EFst ext e2) e3
+ ESnd _ (EMaybe _ e1 e2 e3) ->
+ acted $ simplify' $
+ EMaybe ext (ESnd ext e1) (ESnd ext e2) e3
+
+ -- TODO: more array indexing
+ EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
+ EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1
- -- TODO: array indexing (index of build, index of fold)
+ -- TODO: more array shape
+ EShape _ (EBuild _ _ e _) -> acted $ simplify' e
- -- TODO: beta rules for maybe
+ -- TODO: more constant folding
+ EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext))
+ EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext))
- -- TODO: constant folding for operations
+ -- inline cheap array constructors
+ ELet _ (EReplicate1Inner _ e1 e2) e3 ->
+ acted $ simplify' $
+ ELet ext (EPair ext e1 e2) $
+ let v = EVar ext (STPair tIx (typeOf e2)) IZ
+ in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3
+ -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is
+ -- -- cheap, which it can't be because (!) is not cheap if you do AD after.
+ -- -- Should do proper SoA representation.
+ -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 ->
+ -- acted $ simplify' $
+ -- ELet ext e1 $
+ -- subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3
+
+ -- eta rule for unit
+ e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) ->
+ case e of
+ ENil _ -> return e
+ _ -> acted $ return (ENil ext)
+
+ EBuild _ SZ _ e ->
+ acted $ simplify' $ EUnit ext (substInline (ENil ext) e)
-- monoid rules
- EAccum _ t p e1 e2 acc -> do
- acc' <- simplify' acc
- simplifyOneHotTerm (OneHotTerm t p e1 e2)
- (Any True, ENil ext)
- (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc'))
- (\(OneHotTerm t' p' e1' e2') -> return (EAccum ext t' p' e1' e2' acc'))
- EPlus _ _ (EZero _ _) e -> acted $ simplify' e
- EPlus _ _ e (EZero _ _) -> acted $ simplify' e
- EOneHot _ t p e1 e2 ->
- simplifyOneHotTerm (OneHotTerm t p e1 e2)
- (Any True, EZero ext t)
- (\e -> (Any True, e))
- (\(OneHotTerm t' p' e1' e2') -> return (EOneHot ext t' p' e1' e2'))
+ EAccum _ t p e1 sp e2 acc -> do
+ e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1
+ e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2
+ acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc
+ simplifyOHT (OneHotTerm SAID t p e1' sp e2')
+ (acted $ return (ENil ext))
+ (\sp' (InContext w wrap e) -> do
+ e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e
+ return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')))
+ (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do
+ -- The acted management here is a hideous mess.
+ e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1''
+ e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2''
+ return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')))
+ EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e
+ EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e
+ EOneHot _ t p e1 e2 -> do
+ e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1
+ e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2
+ simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2')
+ (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2)))
+ (\sp' (InContext _ wrap e) ->
+ case isDense t sp' of
+ Just Refl -> do
+ e' <- hideActed $ within wrap $ simplify' e
+ return (wrap e')
+ Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
+ (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) ->
+ case isDense (acPrjTy p' t') sp' of
+ Just Refl -> do
+ e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1''
+ e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2''
+ return (wrap $ EOneHot ext t' p' e1''' e2''')
+ Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
-- type-specific equations for plus
- EPlus _ STNil _ _ -> (Any True, ENil ext)
+ EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) ->
+ acted $ return (ENil ext)
- EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) ->
- acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2))
- EPlus _ STPair{} ENothing{} e -> acted $ simplify' e
- EPlus _ STPair{} e ENothing{} -> acted $ simplify' e
+ EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) ->
+ acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)
- EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) ->
- acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2))
- EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) ->
- acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2))
- EPlus _ STEither{} ENothing{} e -> acted $ simplify' e
- EPlus _ STEither{} e ENothing{} -> acted $ simplify' e
+ EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) ->
+ acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2)
+ EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) ->
+ acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2)
+ EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e
+ EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e
- EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) ->
+ EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) ->
acted $ simplify' $ EJust ext (EPlus ext t e1 e2)
- EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e
- EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e
+ EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e
+ EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e
-- fallback recursion
EVar _ t i -> pure $ EVar ext t i
- ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b
- EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b
- EFst _ e -> EFst ext <$> simplify' e
- ESnd _ e -> ESnd ext <$> simplify' e
+ ELet _ a b -> [simprec| ELet ext *a *b |]
+ EPair _ a b -> [simprec| EPair ext *a *b |]
+ EFst _ e -> [simprec| EFst ext *e |]
+ ESnd _ e -> [simprec| ESnd ext *e |]
ENil _ -> pure $ ENil ext
- EInl _ t e -> EInl ext t <$> simplify' e
- EInr _ t e -> EInr ext t <$> simplify' e
- ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b
+ EInl _ t e -> [simprec| EInl ext t *e |]
+ EInr _ t e -> [simprec| EInr ext t *e |]
+ ECase _ e a b -> [simprec| ECase ext *e *a *b |]
ENothing _ t -> pure $ ENothing ext t
- EJust _ e -> EJust ext <$> simplify' e
- EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e
+ EJust _ e -> [simprec| EJust ext *e |]
+ EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |]
+ ELNil _ t1 t2 -> pure $ ELNil ext t1 t2
+ ELInl _ t e -> [simprec| ELInl ext t *e |]
+ ELInr _ t e -> [simprec| ELInr ext t *e |]
+ ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |]
EConstArr _ n t v -> pure $ EConstArr ext n t v
- EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b
- EFold1Inner _ a b c -> EFold1Inner ext <$> simplify' a <*> simplify' b <*> simplify' c
- ESum1Inner _ e -> ESum1Inner ext <$> simplify' e
- EUnit _ e -> EUnit ext <$> simplify' e
- EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b
- EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e
- EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e
+ EBuild _ n a b -> [simprec| EBuild ext n *a *b |]
+ EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |]
+ ESum1Inner _ e -> [simprec| ESum1Inner ext *e |]
+ EUnit _ e -> [simprec| EUnit ext *e |]
+ EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |]
+ EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
+ EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
EConst _ t v -> pure $ EConst ext t v
- EIdx0 _ e -> EIdx0 ext <$> simplify' e
- EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b
- EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b
- EShape _ e -> EShape ext <$> simplify' e
- EOp _ op e -> EOp ext op <$> simplify' e
- ECustom _ s t p a b c e1 e2 ->
- ECustom ext s t p
- <$> (let ?accumInScope = False in simplify' a)
- <*> (let ?accumInScope = False in simplify' b)
- <*> (let ?accumInScope = False in simplify' c)
- <*> simplify' e1 <*> simplify' e2
- EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
- EZero _ t -> pure $ EZero ext t
- EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b
+ EIdx0 _ e -> [simprec| EIdx0 ext *e |]
+ EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |]
+ EIdx _ a b -> [simprec| EIdx ext *a *b |]
+ EShape _ e -> [simprec| EShape ext *e |]
+ EOp _ op e -> [simprec| EOp ext op *e |]
+ ECustom _ s t p a b c e1 e2 -> do
+ a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a)
+ b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b)
+ c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c)
+ e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1)
+ e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2)
+ pure (ECustom ext s t p a' b' c' e1' e2')
+ ERecompute _ e -> [simprec| ERecompute ext *e |]
+ EWith _ t e1 e2 -> do
+ e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1)
+ e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2)
+ pure (EWith ext t e1' e2')
+ EZero _ t e -> [simprec| EZero ext t *e |]
+ EDeepZero _ t e -> [simprec| EDeepZero ext t *e |]
+ EPlus _ t a b -> [simprec| EPlus ext t *a *b |]
EError _ t s -> pure $ EError ext t s
-acted :: (Any, a) -> (Any, a)
-acted (_, x) = (Any True, x)
-
cheapExpr :: Expr x env t -> Bool
cheapExpr = \case
EVar{} -> True
@@ -204,6 +343,7 @@ cheapExpr = \case
EConst{} -> True
EFst _ e -> cheapExpr e
ESnd _ e -> cheapExpr e
+ EUnit _ e -> cheapExpr e
_ -> False
-- | This can be made more precise by tracking (and not counting) adds on
@@ -222,9 +362,13 @@ hasAdds = \case
ENothing _ _ -> False
EJust _ e -> hasAdds e
EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e
+ ELNil _ _ _ -> False
+ ELInl _ _ e -> hasAdds e
+ ELInr _ _ e -> hasAdds e
+ ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c
EConstArr _ _ _ _ -> False
EBuild _ _ a b -> hasAdds a || hasAdds b
- EFold1Inner _ a b c -> hasAdds a || hasAdds b || hasAdds c
+ EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
ESum1Inner _ e -> hasAdds e
EUnit _ e -> hasAdds e
EReplicate1Inner _ a b -> hasAdds a || hasAdds b
@@ -238,8 +382,10 @@ hasAdds = \case
EShape _ e -> hasAdds e
EOp _ _ e -> hasAdds e
EWith _ _ a b -> hasAdds a || hasAdds b
- EAccum _ _ _ _ _ _ -> True
- EZero _ _ -> False
+ ERecompute _ e -> hasAdds e
+ EAccum _ _ _ _ _ _ _ -> True
+ EZero _ _ e -> hasAdds e
+ EDeepZero _ _ e -> hasAdds e
EPlus _ _ a b -> hasAdds a || hasAdds b
EOneHot _ _ _ a b -> hasAdds a || hasAdds b
EError _ _ _ -> False
@@ -252,49 +398,202 @@ checkAccumInScope = \case SNil -> False
check STNil = False
check (STPair s t) = check s || check t
check (STEither s t) = check s || check t
+ check (STLEither s t) = check s || check t
check (STMaybe t) = check t
check (STArr _ t) = check t
check (STScal _) = False
check STAccum{} = True
-data OneHotTerm env p a b where
- OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b
-deriving instance Show (OneHotTerm env p a b)
-
-simplifyOneHotTerm :: OneHotTerm env p a b
- -> (Any, r) -- ^ Zero case (onehot is actually zero)
- -> (Ex env (D2 a) -> (Any, r)) -- ^ Trivial case (no zeros in onehot)
- -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r))
- -> (Any, r)
-simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero
-simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k
- | Just Refl <- testEquality (acPrjTy prj1 t1) t2
- = do (Any True, ()) -- record, whatever happens later, that we've modified something
- concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
- simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k
-simplifyOneHotTerm (OneHotTerm _ SAPHere _ e) _ ktriv _ = ktriv e
-simplifyOneHotTerm term _ _ k = k term
-
-concatOneHots :: STy a
- -> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
- -> SAcPrj p2 b c -> Ex env (AcIdx p2 b)
- -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r
-concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of
- (_, SAPHere) -> k prj2 idx2
-
- (STPair a _, SAPFst prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12
- (STPair _ b, SAPSnd prj1') ->
- concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12
-
- (STEither a _, SAPLeft prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
- (STEither _ b, SAPRight prj1') ->
- concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
-
- (STMaybe a, SAPJust prj1') ->
- concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
-
- (STArr n a, SAPArrIdx prj1' _) ->
- concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
- k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
+data OneHotTerm dense env a where
+ OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a
+deriving instance Show (OneHotTerm dense env a)
+
+data InContext f env (a :: Ty) where
+ InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a
+
+simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
+simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do
+ val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val
+ return $ OneHotTerm dense t prj idx sp val'
+
+simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a)
+simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) =
+ unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 ->
+ acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' ->
+ return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2)
+simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht
+
+simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
+simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val))
+ | Just Refl <- isDense (acPrjTy prj1 t1) sp =
+ let idx2' :: Ex env (AcIdx dense p2 c)
+ idx2' = case dense of
+ SAID -> reduceAcIdx t2 prj2 idx2
+ SAIS -> idx2
+ in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' ->
+ acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val
+simplifyOHT_concat oht = return oht
+
+-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is
+-- -- dense, then the Sparse in the output will also be dense. This property is
+-- -- used when simplifying EOneHot, which cannot represent sparsity.
+simplifyOHT :: ActedMonad m => OneHotTerm dense env a
+ -> m r -- ^ Zero case (onehot is actually zero)
+ -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot)
+ -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified
+ -> m r
+simplifyOHT oht kzero ktriv k = do
+ -- traceM $ "sOHT: input " ++ show oht
+ oht1 <- simplifyOHT_recogniseMonoid oht
+ -- traceM $ "sOHT: recog " ++ show oht1
+ InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1
+ -- traceM $ "sOHT: unspa " ++ show oht2
+ oht3 <- simplifyOHT_concat oht2
+ -- traceM $ "sOHT: conca " ++ show oht3
+ -- traceM ""
+ case oht3 of
+ OneHotTerm _ _ _ _ _ EZero{} -> kzero
+ OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val)
+ _ -> k (InContext w1 wrap1 oht3)
+
+-- Sets the acted flag whenever a non-trivial projection is returned or the
+-- output Sparse is different from the input Sparse.
+unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a'
+ -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s)
+ -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r
+unsparseOneHotD topsp topval k = case (topsp, topval) of
+ -- eliminate always-Just sparse onehot
+ (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
+ acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k
+
+ -- expand the top levels of a onehot for a sparse type into a onehot for the
+ -- corresponding non-sparse type
+ (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPFst spprj) idx' s1' e'
+ (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
+ unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPSnd spprj) idx' s1' e'
+ (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPLeft spprj) idx' s1' e'
+ (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
+ unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPRight spprj) idx' s1' e'
+ (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPJust spprj) idx' s1' e'
+ (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val)
+ | Dict <- styKnown (typeOf idx) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' ->
+ acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e'
+
+ -- anything else we don't know how to improve
+ _ -> k WId id SAPHere (ENil ext) topsp topval
+
+{-
+unsparseOneHotS :: ActedMonad m
+ => Sparse a a' -> Ex env a'
+ -> (forall b. Sparse a b -> Ex env b -> m r) -> m r
+unsparseOneHotS topsp topval k = case (topsp, topval) of
+ -- order is relevant to make sure we set the acted flag correctly
+ (SpAbsent, v@ENil{}) -> k SpAbsent v
+ (SpAbsent, v@EZero{}) -> k SpAbsent v
+ (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+ (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+ (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+
+ -- the unsparsifying
+ (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
+ acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k
+
+ -- recursion
+ -- TODO: coproducts could safely become projections as they do not need
+ -- zeroinfo. But that would only work if the coproduct is at the top, because
+ -- as soon as we hit a product, we need zeroinfo to make it a projection and
+ -- we don't have that.
+ (SpSparse s, e) -> k (SpSparse s) e
+ (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' ->
+ acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext))
+ (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
+ unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' ->
+ acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e')
+ (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do
+ case s2 of SpAbsent -> pure () ; _ -> tellActed
+ k (SpLEither s1' SpAbsent) (ELInl ext STNil e')
+ (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
+ unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do
+ case s1 of SpAbsent -> pure () ; _ -> tellActed
+ acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e')
+ (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' ->
+ k (SpMaybe s1') (EJust ext e')
+ (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' ->
+ k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e')
+ _ -> _
+-}
+
+-- | Recognises 'EZero' and 'EOneHot'.
+recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t)
+recogniseMonoid _ e@EOneHot{} = return e
+recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext)
+recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) =
+ ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case
+ (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2)
+ (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a'
+ (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b'
+ (a', b') -> return $ EPair ext a' b'
+recogniseMonoid typ@(SMTLEither t1 t2) expr =
+ case expr of
+ ELNil{} -> acted $ return $ EZero ext typ (ENil ext)
+ ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e
+ _ -> return expr
+recogniseMonoid typ@(SMTMaybe t1) expr =
+ case expr of
+ ENothing{} -> acted $ return $ EZero ext typ (ENil ext)
+ EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ _ -> return expr
+recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) =
+ acted $ do
+ e' <- recogniseMonoid t e
+ return $
+ ELet ext e' $
+ EOneHot ext typ (SAPArrIdx SAPHere)
+ (EPair ext (EPair ext (ENil ext) (EUnit ext (makeZeroInfo t (EVar ext (fromSMTy t) IZ))))
+ (ENil ext))
+ (EVar ext (fromSMTy t) IZ)
+recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of
+ (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ _ -> return e
+recogniseMonoid _ e = return e
+
+reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a)
+reduceAcIdx topty topprj e = case (topty, topprj) of
+ (_, SAPHere) -> ENil ext
+ (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e)
+ (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e)
+ (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e
+ (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e
+ (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e
+ (SMTArr _ t, SAPArrIdx p) ->
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (efst e1) (reduceAcIdx t p e2)
+
+zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
+zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e)
+ where
+ -- invariant: AcIdx expression is duplicable
+ go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
+ go t SAPHere _ e = makeZeroInfo t e
+ go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx)
+ go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e)
+ go SMTLEither{} _ _ _ = ENil ext
+ go SMTMaybe{} _ _ _ = ENil ext
+ go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx)