aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/CHAD/AST.hs58
-rw-r--r--src/CHAD/AST/Env.hs4
-rw-r--r--src/CHAD/AST/Sparse/Types.hs33
-rw-r--r--src/CHAD/AST/Types.hs59
-rw-r--r--src/CHAD/AST/UnMonoid.hs5
-rw-r--r--src/CHAD/Drev.hs3
6 files changed, 88 insertions, 74 deletions
diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs
index b795070..3f6dfc4 100644
--- a/src/CHAD/AST.hs
+++ b/src/CHAD/AST.hs
@@ -442,64 +442,6 @@ subst' f w = \case
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
-class KnownScalTy t where knownScalTy :: SScalTy t
-instance KnownScalTy TI32 where knownScalTy = STI32
-instance KnownScalTy TI64 where knownScalTy = STI64
-instance KnownScalTy TF32 where knownScalTy = STF32
-instance KnownScalTy TF64 where knownScalTy = STF64
-instance KnownScalTy TBool where knownScalTy = STBool
-
-class KnownTy t where knownTy :: STy t
-instance KnownTy TNil where knownTy = STNil
-instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
-instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
-instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
-instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
-instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
-instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
-instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
-
-class KnownMTy t where knownMTy :: SMTy t
-instance KnownMTy TNil where knownMTy = SMTNil
-instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy
-instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
-instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
-instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
-instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy
-
-class KnownEnv env where knownEnv :: SList STy env
-instance KnownEnv '[] where knownEnv = SNil
-instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
-
-styKnown :: STy t -> Dict (KnownTy t)
-styKnown STNil = Dict
-styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
-styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
-styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
-styKnown (STMaybe t) | Dict <- styKnown t = Dict
-styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
-styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
-styKnown (STAccum t) | Dict <- smtyKnown t = Dict
-
-smtyKnown :: SMTy t -> Dict (KnownMTy t)
-smtyKnown SMTNil = Dict
-smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
-smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
-smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
-smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
-smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict
-
-sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
-sscaltyKnown STI32 = Dict
-sscaltyKnown STI64 = Dict
-sscaltyKnown STF32 = Dict
-sscaltyKnown STF64 = Dict
-sscaltyKnown STBool = Dict
-
-envKnown :: SList STy env -> Dict (KnownEnv env)
-envKnown SNil = Dict
-envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
-
cheapExpr :: Expr x env t -> Bool
cheapExpr = \case
EVar{} -> True
diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs
index 8e6b745..40b6ca2 100644
--- a/src/CHAD/AST/Env.hs
+++ b/src/CHAD/AST/Env.hs
@@ -53,9 +53,9 @@ subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env)
subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp)
subenvOnehot SNil i _ = case i of {}
-subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3
+subenvCompose :: IsSubType s => Subenv' (:~:) env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3
subenvCompose SETop SETop = SETop
-subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2)
+subenvCompose (SEYes Refl sub1) (SEYes s2 sub2) = SEYes s2 (subenvCompose sub1 sub2)
subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2)
diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs
index 8f41ba4..930475a 100644
--- a/src/CHAD/AST/Sparse/Types.hs
+++ b/src/CHAD/AST/Sparse/Types.hs
@@ -1,14 +1,18 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
module CHAD.AST.Sparse.Types where
import Data.Kind (Type, Constraint)
import Data.Type.Equality
import CHAD.AST.Types
+import CHAD.Data
data Sparse t t' where
@@ -19,9 +23,22 @@ data Sparse t t' where
SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
+ SpArrIdx :: SList (Sparse t) t's -> Sparse (TArr n t) (MultiHot n t's)
SpScal :: Sparse (TScal t) (TScal t)
deriving instance Show (Sparse t t')
+type family MultiHot n t's where
+ MultiHot n '[] = TNil
+ MultiHot n (t' : t's) = TPair (TPair (Tup (Replicate n TIx)) t') (MultiHot n t's)
+
+tMultiHot :: SNat n -> SList STy ts -> STy (MultiHot n ts)
+tMultiHot _ SNil = STNil
+tMultiHot n (t `SCons` ts) = STPair (STPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+
+mtMultiHot :: SNat n -> SList SMTy ts -> SMTy (MultiHot n ts)
+mtMultiHot _ SNil = SMTNil
+mtMultiHot n (t `SCons` ts) = SMTPair (SMTPair (tTup (sreplicate n tIx)) t) (tMultiHot n ts)
+
class ApplySparse f where
applySparse :: Sparse t t' -> f t -> f t'
@@ -32,6 +49,7 @@ instance ApplySparse STy where
applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+ applySparse (SpArrIdx ss) (STArr n t) = tMultiHot n (slistMap (`applySparse` t) ss)
applySparse SpScal t = t
instance ApplySparse SMTy where
@@ -41,34 +59,23 @@ instance ApplySparse SMTy where
applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
+ applySparse (SpArrIdx s) (SMTArr n t) = SMTPair (mkTup SMTNil SMTPair (sreplicate n (knownMTy @TIx))) (applySparse s t)
applySparse SpScal t = t
class IsSubType s where
type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
- subtTrans :: s a b -> s b c -> s a c
subtFull :: IsSubTypeSubject s f => f t -> s t t
instance IsSubType (:~:) where
type IsSubTypeSubject (:~:) f = ()
subtApply = gcastWith
- subtTrans = trans
subtFull _ = Refl
instance IsSubType Sparse where
type IsSubTypeSubject Sparse f = f ~ SMTy
subtApply = applySparse
-
- subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
- subtTrans _ SpAbsent = SpAbsent
- subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
- subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
- subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
- subtTrans SpScal SpScal = SpScal
-
subtFull = spDense
spDense :: SMTy t -> Sparse t t
@@ -95,6 +102,7 @@ isDense (SMTMaybe t) (SpMaybe s)
isDense (SMTArr _ t) (SpArr s)
| Just Refl <- isDense t s = Just Refl
| otherwise = Nothing
+isDense SMTArr{} SpArrIdx{} = Nothing
isDense (SMTScal _) SpScal = Just Refl
isAbsent :: Sparse t t' -> Bool
@@ -104,4 +112,5 @@ isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpMaybe s) = isAbsent s
isAbsent (SpArr s) = isAbsent s
+isAbsent (SpArrIdx s) = isAbsent s
isAbsent SpScal = False
diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs
index f0feb55..bec2201 100644
--- a/src/CHAD/AST/Types.hs
+++ b/src/CHAD/AST/Types.hs
@@ -77,6 +77,7 @@ data SMTy t where
SMTMaybe :: SMTy a -> SMTy (TMaybe a)
SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
+ SMTData :: STy a -> SMTy a -- ^ inclusion of non-monoidal information
deriving instance Show (SMTy t)
instance GCompare SMTy where
@@ -215,3 +216,61 @@ unTup unpack (_ `SCons` list) tup =
type family InvTup core env where
InvTup core '[] = core
InvTup core (t : ts) = InvTup (TPair core t) ts
+
+class KnownScalTy t where knownScalTy :: SScalTy t
+instance KnownScalTy TI32 where knownScalTy = STI32
+instance KnownScalTy TI64 where knownScalTy = STI64
+instance KnownScalTy TF32 where knownScalTy = STF32
+instance KnownScalTy TF64 where knownScalTy = STF64
+instance KnownScalTy TBool where knownScalTy = STBool
+
+class KnownTy t where knownTy :: STy t
+instance KnownTy TNil where knownTy = STNil
+instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
+instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
+instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
+instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
+instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
+
+class KnownMTy t where knownMTy :: SMTy t
+instance KnownMTy TNil where knownMTy = SMTNil
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy
+instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
+instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
+instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy
+
+class KnownEnv env where knownEnv :: SList STy env
+instance KnownEnv '[] where knownEnv = SNil
+instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
+
+styKnown :: STy t -> Dict (KnownTy t)
+styKnown STNil = Dict
+styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STMaybe t) | Dict <- styKnown t = Dict
+styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
+styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
+styKnown (STAccum t) | Dict <- smtyKnown t = Dict
+
+smtyKnown :: SMTy t -> Dict (KnownMTy t)
+smtyKnown SMTNil = Dict
+smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
+smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
+smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict
+
+sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
+sscaltyKnown STI32 = Dict
+sscaltyKnown STI64 = Dict
+sscaltyKnown STF32 = Dict
+sscaltyKnown STF64 = Dict
+sscaltyKnown STBool = Dict
+
+envKnown :: SList STy env -> Dict (KnownEnv env)
+envKnown SNil = Dict
+envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs
index d3cad25..1a66cdf 100644
--- a/src/CHAD/AST/UnMonoid.hs
+++ b/src/CHAD/AST/UnMonoid.hs
@@ -212,6 +212,11 @@ accumulateSparse topty topsp arg accum = case (topty, topsp) of
(EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
(\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
ENil ext
+ (SMTArr _ t, SpArrIdx s) ->
+ eunPair arg $ \w1 e1 e2 ->
+ elet (accumulateSparse t s e2
+ (\w prj idx val -> accum (w .> w1) (SAPArrIdx prj) (EPair ext (weakenExpr w e1) idx) val)) $
+ ENil ext
acPrjCompose
:: SAIDense dense
diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs
index bfa964b..ee92782 100644
--- a/src/CHAD/Drev.hs
+++ b/src/CHAD/Drev.hs
@@ -1009,11 +1009,10 @@ drev des accumMap sd = \case
let smallE = unsafeWeakenWithSubenv usedSub e in
subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 ->
- let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in
Ret (collectBindings (desD1E des) subD1eUsed)
(subenvAll (desD1E usedDes))
(weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e)
- (subenvCompose subMergeUsed' sub)
+ (subenvCompose (subenvD2E subMergeUsed) sub)
(letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $
weakenExpr
(autoWeak (#d (auto1 @sd)