diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-16 10:07:46 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-16 10:07:46 +0100 |
commit | 8ca9ceef96afffdc9d4bc266c978a6b4374131e6 (patch) | |
tree | 26cf63e6ab7799d3ac6f3bbfb7bbce8d982c3e84 | |
parent | 6da98aedf2f28ec8848d1cb8f5605b0c7e64d644 (diff) |
simplifyOneHotTerm
-rw-r--r-- | src/Simplify.hs | 46 |
1 files changed, 32 insertions, 14 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs index 0aa7a66..ac1bb8b 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -1,10 +1,11 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} module Simplify ( @@ -18,6 +19,7 @@ import Data.Type.Equality (testEquality) import AST import AST.Count +import CHAD.Types import Data @@ -120,19 +122,19 @@ simplify' = \case -- TODO: constant folding for operations -- monoid rules - EAccum _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val) acc - | Just Refl <- testEquality (acPrjTy prj1 t1) t2 - -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - acted $ simplify' (EAccum ext t1 prj12 idx12 val acc) - EAccum _ _ _ _ (EZero _ _) _ -> (Any True, ENil ext) + EAccum _ t p e1 e2 acc -> do + acc' <- simplify' acc + simplifyOneHotTerm (OneHotTerm t p e1 e2) + (Any True, ENil ext) + (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc')) + (\(OneHotTerm t' p' e1' e2') -> return (EAccum ext t' p' e1' e2' acc')) EPlus _ _ (EZero _ _) e -> acted $ simplify' e EPlus _ _ e (EZero _ _) -> acted $ simplify' e - EOneHot _ t _ _ (EZero _ _) -> (Any True, EZero ext t) - EOneHot _ _ SAPHere _ e -> acted $ simplify' e - EOneHot _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val) - | Just Refl <- testEquality (acPrjTy prj1 t1) t2 - -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - acted $ simplify' (EOneHot ext t1 prj12 idx12 val) + EOneHot _ t p e1 e2 -> + simplifyOneHotTerm (OneHotTerm t p e1 e2) + (Any True, EZero ext t) + (\e -> (Any True, e)) + (\(OneHotTerm t' p' e1' e2') -> return (EOneHot ext t' p' e1' e2')) -- type-specific equations for plus EPlus _ STNil _ _ -> (Any True, ENil ext) @@ -188,10 +190,8 @@ simplify' = \case <*> (let ?accumInScope = False in simplify' c) <*> simplify' e1 <*> simplify' e2 EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) - EAccum _ t i e1 e2 e3 -> EAccum ext t i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3 EZero _ t -> pure $ EZero ext t EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b - EOneHot _ t i a b -> EOneHot ext t i <$> simplify' a <*> simplify' b EError _ t s -> pure $ EError ext t s acted :: (Any, a) -> (Any, a) @@ -257,6 +257,24 @@ checkAccumInScope = \case SNil -> False check (STScal _) = False check STAccum{} = True +data OneHotTerm env p a b where + OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b +deriving instance Show (OneHotTerm env p a b) + +simplifyOneHotTerm :: OneHotTerm env p a b + -> (Any, r) -- ^ Zero case (onehot is actually zero) + -> (Ex env (D2 a) -> (Any, r)) -- ^ Trivial case (no zeros in onehot) + -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r)) + -> (Any, r) +simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero +simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k + | Just Refl <- testEquality (acPrjTy prj1 t1) t2 + = do (Any True, ()) -- record, whatever happens later, that we've modified something + concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> + simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k +simplifyOneHotTerm (OneHotTerm _ SAPHere _ e) _ ktriv _ = ktriv e +simplifyOneHotTerm term _ _ k = k term + concatOneHots :: STy a -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) |