summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-16 10:07:46 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-16 10:07:46 +0100
commit8ca9ceef96afffdc9d4bc266c978a6b4374131e6 (patch)
tree26cf63e6ab7799d3ac6f3bbfb7bbce8d982c3e84
parent6da98aedf2f28ec8848d1cb8f5605b0c7e64d644 (diff)
simplifyOneHotTerm
-rw-r--r--src/Simplify.hs46
1 files changed, 32 insertions, 14 deletions
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 0aa7a66..ac1bb8b 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -1,10 +1,11 @@
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Simplify (
@@ -18,6 +19,7 @@ import Data.Type.Equality (testEquality)
import AST
import AST.Count
+import CHAD.Types
import Data
@@ -120,19 +122,19 @@ simplify' = \case
-- TODO: constant folding for operations
-- monoid rules
- EAccum _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val) acc
- | Just Refl <- testEquality (acPrjTy prj1 t1) t2
- -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
- acted $ simplify' (EAccum ext t1 prj12 idx12 val acc)
- EAccum _ _ _ _ (EZero _ _) _ -> (Any True, ENil ext)
+ EAccum _ t p e1 e2 acc -> do
+ acc' <- simplify' acc
+ simplifyOneHotTerm (OneHotTerm t p e1 e2)
+ (Any True, ENil ext)
+ (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc'))
+ (\(OneHotTerm t' p' e1' e2') -> return (EAccum ext t' p' e1' e2' acc'))
EPlus _ _ (EZero _ _) e -> acted $ simplify' e
EPlus _ _ e (EZero _ _) -> acted $ simplify' e
- EOneHot _ t _ _ (EZero _ _) -> (Any True, EZero ext t)
- EOneHot _ _ SAPHere _ e -> acted $ simplify' e
- EOneHot _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)
- | Just Refl <- testEquality (acPrjTy prj1 t1) t2
- -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
- acted $ simplify' (EOneHot ext t1 prj12 idx12 val)
+ EOneHot _ t p e1 e2 ->
+ simplifyOneHotTerm (OneHotTerm t p e1 e2)
+ (Any True, EZero ext t)
+ (\e -> (Any True, e))
+ (\(OneHotTerm t' p' e1' e2') -> return (EOneHot ext t' p' e1' e2'))
-- type-specific equations for plus
EPlus _ STNil _ _ -> (Any True, ENil ext)
@@ -188,10 +190,8 @@ simplify' = \case
<*> (let ?accumInScope = False in simplify' c)
<*> simplify' e1 <*> simplify' e2
EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
- EAccum _ t i e1 e2 e3 -> EAccum ext t i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
EZero _ t -> pure $ EZero ext t
EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b
- EOneHot _ t i a b -> EOneHot ext t i <$> simplify' a <*> simplify' b
EError _ t s -> pure $ EError ext t s
acted :: (Any, a) -> (Any, a)
@@ -257,6 +257,24 @@ checkAccumInScope = \case SNil -> False
check (STScal _) = False
check STAccum{} = True
+data OneHotTerm env p a b where
+ OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b
+deriving instance Show (OneHotTerm env p a b)
+
+simplifyOneHotTerm :: OneHotTerm env p a b
+ -> (Any, r) -- ^ Zero case (onehot is actually zero)
+ -> (Ex env (D2 a) -> (Any, r)) -- ^ Trivial case (no zeros in onehot)
+ -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r))
+ -> (Any, r)
+simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero
+simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k
+ | Just Refl <- testEquality (acPrjTy prj1 t1) t2
+ = do (Any True, ()) -- record, whatever happens later, that we've modified something
+ concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
+ simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k
+simplifyOneHotTerm (OneHotTerm _ SAPHere _ e) _ ktriv _ = ktriv e
+simplifyOneHotTerm term _ _ k = k term
+
concatOneHots :: STy a
-> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
-> SAcPrj p2 b c -> Ex env (AcIdx p2 b)