aboutsummaryrefslogtreecommitdiff
path: root/src/Simplify.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-10 21:49:45 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-10 21:50:25 +0100
commit174af2ba568de66e0d890825b8bda930b8e7bb96 (patch)
tree5a20f52662e87ff7cf6a6bef5db0713aa6c7884e /src/Simplify.hs
parent92bca235e3aaa287286b6af082d3fce585825a35 (diff)
Move module hierarchy under CHAD.
Diffstat (limited to 'src/Simplify.hs')
-rw-r--r--src/Simplify.hs619
1 files changed, 0 insertions, 619 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
deleted file mode 100644
index 19d0c17..0000000
--- a/src/Simplify.hs
+++ /dev/null
@@ -1,619 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE ImplicitParams #-}
-{-# LANGUAGE KindSignatures #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE MultiWayIf #-}
-{-# LANGUAGE QuasiQuotes #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeOperators #-}
-module Simplify (
- simplifyN, simplifyFix,
- SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith,
-) where
-
-import Control.Monad (ap)
-import Data.Bifunctor (first)
-import Data.Function (fix)
-import Data.Monoid (Any(..))
-
-import Debug.Trace
-
-import AST
-import AST.Count
-import AST.Pretty
-import AST.Sparse.Types
-import AST.UnMonoid (acPrjCompose)
-import Data
-import Simplify.TH
-
-
-data SimplifyConfig = SimplifyConfig
- { scLogging :: Bool
- }
-
-defaultSimplifyConfig :: SimplifyConfig
-defaultSimplifyConfig = SimplifyConfig False
-
-simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
-simplifyN 0 = id
-simplifyN n = simplifyN (n - 1) . simplify
-
-simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
-simplify =
- let ?accumInScope = checkAccumInScope @env knownEnv
- ?config = defaultSimplifyConfig
- in snd . runSM . simplify'
-
-simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
-simplifyWith config =
- let ?accumInScope = checkAccumInScope @env knownEnv
- ?config = config
- in snd . runSM . simplify'
-
-simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t
-simplifyFix = simplifyFixWith defaultSimplifyConfig
-
-simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t
-simplifyFixWith config =
- let ?accumInScope = checkAccumInScope @env knownEnv
- ?config = config
- in fix $ \loop e ->
- let (act, e') = runSM (simplify' e)
- in if act then loop e' else e'
-
--- | simplify monad
-newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a))
- deriving (Functor)
-
-instance Applicative (SM tenv tt env t) where
- pure x = SM (\_ -> (Any False, x))
- (<*>) = ap
-
-instance Monad (SM tenv tt env t) where
- SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx
-
-runSM :: SM env t env t a -> (Bool, a)
-runSM (SM f) = first getAny (f id)
-
-smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt)
-smReconstruct core = SM (\ctx -> (Any False, ctx core))
-
-class Monad m => ActedMonad m where
- tellActed :: m ()
- hideActed :: m a -> m a
- liftActed :: (Any, a) -> m a
-
-instance ActedMonad ((,) Any) where
- tellActed = (Any True, ())
- hideActed (_, x) = (Any False, x)
- liftActed = id
-
-instance ActedMonad (SM tenv tt env t) where
- tellActed = SM (\_ -> tellActed)
- hideActed (SM f) = SM (\ctx -> hideActed (f ctx))
- liftActed pair = SM (\_ -> pair)
-
--- more convenient in practice
-acted :: ActedMonad m => m a -> m a
-acted m = tellActed >> m
-
-within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a
-within subctx (SM f) = SM $ \ctx -> f (ctx . subctx)
-
-simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
-simplify' expr
- | scLogging ?config = do
- res <- simplify'Rec expr
- full <- smReconstruct res
- let printed = ppExpr knownEnv full
- replace a bs = concatMap (\x -> if x == a then bs else [x])
- str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed
- | otherwise = "--- simplify step: " ++ printed
- traceM str
- return res
- | otherwise = simplify'Rec expr
-
-simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
-simplify'Rec = \case
- -- inlining
- ELet _ rhs body
- | cheapExpr rhs
- -> acted $ simplify' (substInline rhs body)
-
- | Occ lexOcc runOcc <- occCount IZ body
- , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply
- || (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not
- -> acted $ simplify' (substInline rhs body)
-
- -- let splitting / let peeling
- ELet _ (EPair _ a b) body ->
- acted $ simplify' $
- ELet ext a $
- ELet ext (weakenExpr WSink b) $
- subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
- IS i -> EVar ext t (IS (IS i)))
- body
- ELet _ (EJust _ a) body ->
- acted $ simplify' $ ELet ext a $ subst0 (EJust ext (EVar ext (typeOf a) IZ)) body
- ELet _ (EInl _ t2 a) body ->
- acted $ simplify' $ ELet ext a $ subst0 (EInl ext t2 (EVar ext (typeOf a) IZ)) body
- ELet _ (EInr _ t1 a) body ->
- acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body
-
- -- let rotation
- ELet _ (ELet _ rhs a) b -> do
- b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b
- acted $ simplify' $
- ELet ext rhs $
- ELet ext a $
- weakenExpr (WCopy WSink) b'
-
- -- beta rules for products
- EFst _ (EPair _ e e')
- | not (hasAdds e') -> acted $ simplify' e
- | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e)
- ESnd _ (EPair _ e' e)
- | not (hasAdds e') -> acted $ simplify' e
- | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e)
-
- -- beta rules for coproducts
- ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs)
- ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs)
-
- -- beta rules for maybe
- EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1
- EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1
-
- -- let floating
- EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body))
- ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body))
- ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))
- EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))
- EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
- EAccum _ t p e1 sp (ELet _ rhs body) acc ->
- acted $ simplify' $
- ELet ext rhs $
- EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc)
-
- -- let () = e in () ~> e
- ELet _ e1 (ENil _) | STNil <- typeOf e1 ->
- acted $ simplify' e1
-
- -- map (\_ -> x) e ~> build (shape e) (\_ -> x)
- EMap _ e1 e2
- | Occ Zero Zero <- occCount IZ e1
- , STArr n _ <- typeOf e2 ->
- acted $ simplify' $
- EBuild ext n (EShape ext e2) $
- subst (\_ t' -> \case IZ -> error "Unused variable was used"
- IS i -> EVar ext t' (IS i))
- e1
-
- -- vertical fusion
- EMap _ e1 (EMap _ e2 e3) ->
- acted $ simplify' $
- EMap ext (ELet ext e2 (weakenExpr (WCopy WSink) e1)) e3
-
- -- projection down-commuting
- EFst _ (ECase _ e1 e2 e3) ->
- acted $ simplify' $
- ECase ext e1 (EFst ext e2) (EFst ext e3)
- ESnd _ (ECase _ e1 e2 e3) ->
- acted $ simplify' $
- ECase ext e1 (ESnd ext e2) (ESnd ext e3)
- EFst _ (EMaybe _ e1 e2 e3) ->
- acted $ simplify' $
- EMaybe ext (EFst ext e1) (EFst ext e2) e3
- ESnd _ (EMaybe _ e1 e2 e3) ->
- acted $ simplify' $
- EMaybe ext (ESnd ext e1) (ESnd ext e2) e3
-
- -- TODO: more array indexing
- EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2
- EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1
- EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
- EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1
-
- -- TODO: more array shape
- EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1
- EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2)
-
- -- TODO: more constant folding
- EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext))
- EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext))
-
- -- inline cheap array constructors
- ELet _ (EReplicate1Inner _ e1 e2) e3 ->
- acted $ simplify' $
- ELet ext (EPair ext e1 e2) $
- let v = EVar ext (STPair tIx (typeOf e2)) IZ
- in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3
- -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is
- -- -- cheap, which it can't be because (!) is not cheap if you do AD after.
- -- -- Should do proper SoA representation.
- -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 ->
- -- acted $ simplify' $
- -- ELet ext e1 $
- -- subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3
-
- -- eta rule for unit
- e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) ->
- case e of
- ENil _ -> return e
- _ -> acted $ return (ENil ext)
-
- EBuild _ SZ _ e ->
- acted $ simplify' $ EUnit ext (substInline (ENil ext) e)
-
- -- monoid rules
- EAccum _ t p e1 sp e2 acc -> do
- e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1
- e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2
- acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc
- simplifyOHT (OneHotTerm SAID t p e1' sp e2')
- (acted $ return (ENil ext))
- (\sp' (InContext w wrap e) -> do
- e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e
- return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')))
- (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do
- -- The acted management here is a hideous mess.
- e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1''
- e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2''
- return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')))
- EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e
- EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e
- EOneHot _ t p e1 e2 -> do
- e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1
- e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2
- simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2')
- (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2)))
- (\sp' (InContext _ wrap e) ->
- case isDense t sp' of
- Just Refl -> do
- e' <- hideActed $ within wrap $ simplify' e
- return (wrap e')
- Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
- (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) ->
- case isDense (acPrjTy p' t') sp' of
- Just Refl -> do
- e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1''
- e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2''
- return (wrap $ EOneHot ext t' p' e1''' e2''')
- Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
-
- -- type-specific equations for plus
- EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) ->
- acted $ return (ENil ext)
-
- EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) ->
- acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)
-
- EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) ->
- acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2)
- EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) ->
- acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2)
- EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e
- EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e
-
- EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) ->
- acted $ simplify' $ EJust ext (EPlus ext t e1 e2)
- EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e
- EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e
-
- -- fallback recursion
- EVar _ t i -> pure $ EVar ext t i
- ELet _ a b -> [simprec| ELet ext *a *b |]
- EPair _ a b -> [simprec| EPair ext *a *b |]
- EFst _ e -> [simprec| EFst ext *e |]
- ESnd _ e -> [simprec| ESnd ext *e |]
- ENil _ -> pure $ ENil ext
- EInl _ t e -> [simprec| EInl ext t *e |]
- EInr _ t e -> [simprec| EInr ext t *e |]
- ECase _ e a b -> [simprec| ECase ext *e *a *b |]
- ENothing _ t -> pure $ ENothing ext t
- EJust _ e -> [simprec| EJust ext *e |]
- EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |]
- ELNil _ t1 t2 -> pure $ ELNil ext t1 t2
- ELInl _ t e -> [simprec| ELInl ext t *e |]
- ELInr _ t e -> [simprec| ELInr ext t *e |]
- ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |]
- EConstArr _ n t v -> pure $ EConstArr ext n t v
- EBuild _ n a b -> [simprec| EBuild ext n *a *b |]
- EMap _ a b -> [simprec| EMap ext *a *b |]
- EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |]
- ESum1Inner _ e -> [simprec| ESum1Inner ext *e |]
- EUnit _ e -> [simprec| EUnit ext *e |]
- EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |]
- EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
- EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
- EReshape _ n a b -> [simprec| EReshape ext n *a *b |]
- EZip _ a b -> [simprec| EZip ext *a *b |]
- EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |]
- EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |]
- EConst _ t v -> pure $ EConst ext t v
- EIdx0 _ e -> [simprec| EIdx0 ext *e |]
- EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |]
- EIdx _ a b -> [simprec| EIdx ext *a *b |]
- EShape _ e -> [simprec| EShape ext *e |]
- EOp _ op e -> [simprec| EOp ext op *e |]
- ECustom _ s t p a b c e1 e2 -> do
- a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a)
- b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b)
- c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c)
- e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1)
- e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2)
- pure (ECustom ext s t p a' b' c' e1' e2')
- ERecompute _ e -> [simprec| ERecompute ext *e |]
- EWith _ t e1 e2 -> do
- e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1)
- e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2)
- pure (EWith ext t e1' e2')
- -- EOneHot _ t p e1 e2 -> [simprec| EOneHot ext t p *e1 *e2 |]
- -- EAccum _ t p e1 sp e2 acc -> [simprec| EAccum ext t p *e1 sp *e2 *acc |]
- EZero _ t e -> [simprec| EZero ext t *e |]
- EDeepZero _ t e -> [simprec| EDeepZero ext t *e |]
- EPlus _ t a b -> [simprec| EPlus ext t *a *b |]
- EError _ t s -> pure $ EError ext t s
-
--- | This can be made more precise by tracking (and not counting) adds on
--- locally eliminated accumulators.
-hasAdds :: Expr x env t -> Bool
-hasAdds = \case
- EVar _ _ _ -> False
- ELet _ rhs body -> hasAdds rhs || hasAdds body
- EPair _ a b -> hasAdds a || hasAdds b
- EFst _ e -> hasAdds e
- ESnd _ e -> hasAdds e
- ENil _ -> False
- EInl _ _ e -> hasAdds e
- EInr _ _ e -> hasAdds e
- ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b
- ENothing _ _ -> False
- EJust _ e -> hasAdds e
- EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e
- ELNil _ _ _ -> False
- ELInl _ _ e -> hasAdds e
- ELInr _ _ e -> hasAdds e
- ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c
- EConstArr _ _ _ _ -> False
- EBuild _ _ a b -> hasAdds a || hasAdds b
- EMap _ a b -> hasAdds a || hasAdds b
- EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
- ESum1Inner _ e -> hasAdds e
- EUnit _ e -> hasAdds e
- EReplicate1Inner _ a b -> hasAdds a || hasAdds b
- EMaximum1Inner _ e -> hasAdds e
- EMinimum1Inner _ e -> hasAdds e
- EReshape _ _ a b -> hasAdds a || hasAdds b
- EZip _ a b -> hasAdds a || hasAdds b
- EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
- EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
- ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
- EConst _ _ _ -> False
- EIdx0 _ e -> hasAdds e
- EIdx1 _ a b -> hasAdds a || hasAdds b
- EIdx _ a b -> hasAdds a || hasAdds b
- EShape _ e -> hasAdds e
- EOp _ _ e -> hasAdds e
- EWith _ _ a b -> hasAdds a || hasAdds b
- ERecompute _ e -> hasAdds e
- EAccum _ _ _ _ _ _ _ -> True
- EZero _ _ e -> hasAdds e
- EDeepZero _ _ e -> hasAdds e
- EPlus _ _ a b -> hasAdds a || hasAdds b
- EOneHot _ _ _ a b -> hasAdds a || hasAdds b
- EError _ _ _ -> False
-
-checkAccumInScope :: SList STy env -> Bool
-checkAccumInScope = \case SNil -> False
- SCons t env -> check t || checkAccumInScope env
- where
- check :: STy t -> Bool
- check STNil = False
- check (STPair s t) = check s || check t
- check (STEither s t) = check s || check t
- check (STLEither s t) = check s || check t
- check (STMaybe t) = check t
- check (STArr _ t) = check t
- check (STScal _) = False
- check STAccum{} = True
-
-data OneHotTerm dense env a where
- OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a
-deriving instance Show (OneHotTerm dense env a)
-
-data InContext f env (a :: Ty) where
- InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a
-
-simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
-simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do
- val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val
- return $ OneHotTerm dense t prj idx sp val'
-
-simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a)
-simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) =
- unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 ->
- acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' ->
- return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2)
-simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht
-
-simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
-simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val))
- | Just Refl <- isDense (acPrjTy prj1 t1) sp =
- let idx2' :: Ex env (AcIdx dense p2 c)
- idx2' = case dense of
- SAID -> reduceAcIdx t2 prj2 idx2
- SAIS -> idx2
- in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' ->
- acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val
-simplifyOHT_concat oht = return oht
-
--- -- Property not expressed in types: if the Sparse in the input OneHotTerm is
--- -- dense, then the Sparse in the output will also be dense. This property is
--- -- used when simplifying EOneHot, which cannot represent sparsity.
-simplifyOHT :: ActedMonad m => OneHotTerm dense env a
- -> m r -- ^ Zero case (onehot is actually zero)
- -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot)
- -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified
- -> m r
-simplifyOHT oht kzero ktriv k = do
- -- traceM $ "sOHT: input " ++ show oht
- oht1 <- simplifyOHT_recogniseMonoid oht
- -- traceM $ "sOHT: recog " ++ show oht1
- InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1
- -- traceM $ "sOHT: unspa " ++ show oht2
- oht3 <- simplifyOHT_concat oht2
- -- traceM $ "sOHT: conca " ++ show oht3
- -- traceM ""
- case oht3 of
- OneHotTerm _ _ _ _ _ EZero{} -> kzero
- OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val)
- _ -> k (InContext w1 wrap1 oht3)
-
--- Sets the acted flag whenever a non-trivial projection is returned or the
--- output Sparse is different from the input Sparse.
-unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a'
- -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s)
- -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r
-unsparseOneHotD topsp topval k = case (topsp, topval) of
- -- eliminate always-Just sparse onehot
- (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
- acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k
-
- -- expand the top levels of a onehot for a sparse type into a onehot for the
- -- corresponding non-sparse type
- (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
- unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' ->
- acted $ k w wrap (SAPFst spprj) idx' s1' e'
- (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
- unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' ->
- acted $ k w wrap (SAPSnd spprj) idx' s1' e'
- (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
- unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
- acted $ k w wrap (SAPLeft spprj) idx' s1' e'
- (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
- unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' ->
- acted $ k w wrap (SAPRight spprj) idx' s1' e'
- (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
- unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
- acted $ k w wrap (SAPJust spprj) idx' s1' e'
- (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val)
- | Dict <- styKnown (typeOf idx) ->
- unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' ->
- acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e'
-
- -- anything else we don't know how to improve
- _ -> k WId id SAPHere (ENil ext) topsp topval
-
-{-
-unsparseOneHotS :: ActedMonad m
- => Sparse a a' -> Ex env a'
- -> (forall b. Sparse a b -> Ex env b -> m r) -> m r
-unsparseOneHotS topsp topval k = case (topsp, topval) of
- -- order is relevant to make sure we set the acted flag correctly
- (SpAbsent, v@ENil{}) -> k SpAbsent v
- (SpAbsent, v@EZero{}) -> k SpAbsent v
- (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
- (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
- (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
-
- -- the unsparsifying
- (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
- acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k
-
- -- recursion
- -- TODO: coproducts could safely become projections as they do not need
- -- zeroinfo. But that would only work if the coproduct is at the top, because
- -- as soon as we hit a product, we need zeroinfo to make it a projection and
- -- we don't have that.
- (SpSparse s, e) -> k (SpSparse s) e
- (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
- unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' ->
- acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext))
- (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
- unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' ->
- acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e')
- (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
- unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do
- case s2 of SpAbsent -> pure () ; _ -> tellActed
- k (SpLEither s1' SpAbsent) (ELInl ext STNil e')
- (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
- unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do
- case s1 of SpAbsent -> pure () ; _ -> tellActed
- acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e')
- (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
- unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' ->
- k (SpMaybe s1') (EJust ext e')
- (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) ->
- unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' ->
- k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e')
- _ -> _
--}
-
--- | Recognises 'EZero' and 'EOneHot'.
-recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t)
-recogniseMonoid _ e@EOneHot{} = return e
-recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext)
-recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) =
- ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case
- (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2)
- (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a'
- (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b'
- (a', b') -> return $ EPair ext a' b'
-recogniseMonoid typ@(SMTLEither t1 t2) expr =
- case expr of
- ELNil{} -> acted $ return $ EZero ext typ (ENil ext)
- ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e
- ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e
- _ -> return expr
-recogniseMonoid typ@(SMTMaybe t1) expr =
- case expr of
- ENothing{} -> acted $ return $ EZero ext typ (ENil ext)
- EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e
- _ -> return expr
-recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) =
- acted $ do
- e' <- recogniseMonoid t e
- return $
- ELet ext e' $
- EOneHot ext typ (SAPArrIdx SAPHere)
- (EPair ext (EPair ext (ENil ext) (EUnit ext (makeZeroInfo t (EVar ext (fromSMTy t) IZ))))
- (ENil ext))
- (EVar ext (fromSMTy t) IZ)
-recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of
- (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext)
- (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext)
- (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext)
- (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext)
- _ -> return e
-recogniseMonoid _ e = return e
-
-reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a)
-reduceAcIdx topty topprj e = case (topty, topprj) of
- (_, SAPHere) -> ENil ext
- (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e)
- (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e)
- (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e
- (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e
- (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e
- (SMTArr _ t, SAPArrIdx p) ->
- eunPair e $ \_ e1 e2 ->
- EPair ext (efst e1) (reduceAcIdx t p e2)
-
-zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
-zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e)
- where
- -- invariant: AcIdx expression is duplicable
- go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t)
- go t SAPHere _ e = makeZeroInfo t e
- go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx)
- go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e)
- go SMTLEither{} _ _ _ = ENil ext
- go SMTMaybe{} _ _ _ = ENil ext
- go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx)