From 20f7d7be13cd7869b338f98d1ab3fd33e8bbfb3e Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 27 Nov 2025 21:30:17 +0100 Subject: WIP user-specified custom types The big roadblock encountered is that accumulation wants addition of monoids to be elementwise float addition; this fundamentally clashes with the concept of a user type with a custom zero and plus. --- src/CHAD/AST/Sparse.hs | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) (limited to 'src/CHAD/AST/Sparse.hs') diff --git a/src/CHAD/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs index 85f2882..30e6b6f 100644 --- a/src/CHAD/AST/Sparse.hs +++ b/src/CHAD/AST/Sparse.hs @@ -1,7 +1,9 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where @@ -9,8 +11,10 @@ module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) wh import Data.Type.Equality import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Env import CHAD.AST.Sparse.Types -import CHAD.Data (SBool(..)) +import CHAD.Data sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' @@ -43,6 +47,7 @@ sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 +sparsePlus (SMTUser t) SpUser e1 e2 = EPlus ext (SMTUser t) e1 e2 cheapZero :: SMTy t -> Maybe (forall env. Ex env t) @@ -61,6 +66,19 @@ cheapZero (SMTScal t) = case t of STI64 -> Just (EConst ext t 0) STF32 -> Just (EConst ext t 0.0) STF64 -> Just (EConst ext t 0.0) +cheapZero (SMTUser t) = + let zero1 = euserZero t (EVar ext (userZeroInfo t) IZ) + occenv1 = occCountAll @_ @'[_] zero1 + zero2 = euserZero t (euserZeroInfo t (EVar ext (userRepTy t) IZ)) + occenv2 = occCountAll @_ @'[_] zero2 + in deleteUnused (userZeroInfo t `SCons` SNil) occenv1 $ \case + sub@(SENo SETop) | cheapExpr zero1 -> + Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero1))) + _ -> + deleteUnused (userRepTy t `SCons` SNil) occenv2 $ \case + sub@(SENo SETop) | cheapExpr zero2 -> + Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero2))) + _ -> Nothing data Injection sp a b where @@ -294,3 +312,6 @@ sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = -- scalars sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) + +-- user types +sparsePlusS _ _ (SMTUser t) SpUser SpUser k = k SpUser (Inj id) (Inj id) (EPlus ext (SMTUser t)) -- cgit v1.2.3-70-g09d2