aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST/Sparse.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-27 21:30:17 +0100
commit20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e (patch)
treea21c90034a02cdeb7240563dbbab355e49622d0a /src/CHAD/AST/Sparse.hs
parentae634c056b500a568b2d89b7f8e225404a2c0c62 (diff)
WIP user-specified custom typesuser-types
The big roadblock encountered is that accumulation wants addition of monoids to be elementwise float addition; this fundamentally clashes with the concept of a user type with a custom zero and plus.
Diffstat (limited to 'src/CHAD/AST/Sparse.hs')
-rw-r--r--src/CHAD/AST/Sparse.hs23
1 files changed, 22 insertions, 1 deletions
diff --git a/src/CHAD/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs
index 85f2882..30e6b6f 100644
--- a/src/CHAD/AST/Sparse.hs
+++ b/src/CHAD/AST/Sparse.hs
@@ -1,7 +1,9 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where
@@ -9,8 +11,10 @@ module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) wh
import Data.Type.Equality
import CHAD.AST
+import CHAD.AST.Count
+import CHAD.AST.Env
import CHAD.AST.Sparse.Types
-import CHAD.Data (SBool(..))
+import CHAD.Data
sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
@@ -43,6 +47,7 @@ sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 =
(EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ))))
sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2
sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
+sparsePlus (SMTUser t) SpUser e1 e2 = EPlus ext (SMTUser t) e1 e2
cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
@@ -61,6 +66,19 @@ cheapZero (SMTScal t) = case t of
STI64 -> Just (EConst ext t 0)
STF32 -> Just (EConst ext t 0.0)
STF64 -> Just (EConst ext t 0.0)
+cheapZero (SMTUser t) =
+ let zero1 = euserZero t (EVar ext (userZeroInfo t) IZ)
+ occenv1 = occCountAll @_ @'[_] zero1
+ zero2 = euserZero t (euserZeroInfo t (EVar ext (userRepTy t) IZ))
+ occenv2 = occCountAll @_ @'[_] zero2
+ in deleteUnused (userZeroInfo t `SCons` SNil) occenv1 $ \case
+ sub@(SENo SETop) | cheapExpr zero1 ->
+ Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero1)))
+ _ ->
+ deleteUnused (userRepTy t `SCons` SNil) occenv2 $ \case
+ sub@(SENo SETop) | cheapExpr zero2 ->
+ Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero2)))
+ _ -> Nothing
data Injection sp a b where
@@ -294,3 +312,6 @@ sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
-- scalars
sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))
+
+-- user types
+sparsePlusS _ _ (SMTUser t) SpUser SpUser k = k SpUser (Inj id) (Inj id) (EPlus ext (SMTUser t))