aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/AST')
-rw-r--r--src/CHAD/AST/Accum.hs34
-rw-r--r--src/CHAD/AST/Count.hs10
-rw-r--r--src/CHAD/AST/Env.hs2
-rw-r--r--src/CHAD/AST/Pretty.hs12
-rw-r--r--src/CHAD/AST/Sparse.hs23
-rw-r--r--src/CHAD/AST/Sparse/Types.hs7
-rw-r--r--src/CHAD/AST/SplitLets.hs4
-rw-r--r--src/CHAD/AST/Types.hs102
-rw-r--r--src/CHAD/AST/Types/Ty.hs20
-rw-r--r--src/CHAD/AST/UnMonoid.hs9
-rw-r--r--src/CHAD/AST/UnUser.hs102
11 files changed, 272 insertions, 53 deletions
diff --git a/src/CHAD/AST/Accum.hs b/src/CHAD/AST/Accum.hs
index ea74a95..f61f00f 100644
--- a/src/CHAD/AST/Accum.hs
+++ b/src/CHAD/AST/Accum.hs
@@ -76,40 +76,6 @@ acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t
acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t
acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t
-type family ZeroInfo t where
- ZeroInfo TNil = TNil
- ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b)
- ZeroInfo (TLEither a b) = TNil
- ZeroInfo (TMaybe a) = TNil
- ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
- ZeroInfo (TScal t) = TNil
-
-tZeroInfo :: SMTy t -> STy (ZeroInfo t)
-tZeroInfo SMTNil = STNil
-tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b)
-tZeroInfo (SMTLEither _ _) = STNil
-tZeroInfo (SMTMaybe _) = STNil
-tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
-tZeroInfo (SMTScal _) = STNil
-
--- | Info needed to create a zero-valued deep accumulator for a monoid type.
--- Should be constructable from a D1.
-type family DeepZeroInfo t where
- DeepZeroInfo TNil = TNil
- DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
- DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
- DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
- DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
- DeepZeroInfo (TScal t) = TNil
-
-tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
-tDeepZeroInfo SMTNil = STNil
-tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
-tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
-tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
-tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
-tDeepZeroInfo (SMTScal _) = STNil
-
-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
-- type family AccumInfo t where
diff --git a/src/CHAD/AST/Count.hs b/src/CHAD/AST/Count.hs
index 46173d2..8923e13 100644
--- a/src/CHAD/AST/Count.hs
+++ b/src/CHAD/AST/Count.hs
@@ -880,6 +880,16 @@ occCountX initialS topexpr k = case topexpr of
EError _ t msg ->
k OccEnd $ \_ -> EError ext (applySubstruc s t) msg
+
+ EUser _ t e ->
+ occCountX SsFull e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EUser ext t (mke env')
+
+ EUnUser _ e ->
+ occCountX SsFull e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EUnUser ext (mke env')
where
s = simplifySubstruc (typeOf topexpr) initialS
diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs
index 8e6b745..73f2dcc 100644
--- a/src/CHAD/AST/Env.hs
+++ b/src/CHAD/AST/Env.hs
@@ -11,7 +11,7 @@ module CHAD.AST.Env where
import Data.Type.Equality
-import CHAD.AST.Sparse
+import CHAD.AST.Sparse.Types
import CHAD.AST.Weaken
import CHAD.Data
import CHAD.Drev.Types
diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs
index 9ddcb35..d9ac8b2 100644
--- a/src/CHAD/AST/Pretty.hs
+++ b/src/CHAD/AST/Pretty.hs
@@ -374,6 +374,15 @@ ppExpr' d val expr = case expr of
EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
+ EUser _ t@STUser{} e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $
+ ppApp (ppString ("user[" ++ show (typeOfProxy t) ++ "]") <> ppX expr) [e']
+
+ EUnUser _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppApp (ppString "unuser" <> ppX expr) [e']
+
ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc
ppExprLet d val etop = do
let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc)
@@ -421,6 +430,7 @@ ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++
ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s
ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s
ppSparse (SMTScal _) SpScal = "."
+ppSparse (SMTUser _) SpUser = "U"
ppCommut :: Commutative -> String
ppCommut Commut = "(C)"
@@ -469,6 +479,7 @@ ppSTy' _ (STScal sty) = ppString $ case sty of
STF64 -> "f64"
STBool -> "bool"
ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t
+ppSTy' d (STUser t) = ppParen (d > 10) $ ppString ("User " ++ showsPrec 11 (typeOfProxy t) "")
ppSMTy :: Int -> SMTy t -> String
ppSMTy d ty = render $ ppSMTy' d ty
@@ -485,6 +496,7 @@ ppSMTy' _ (SMTScal sty) = ppString $ case sty of
STI64 -> "i64"
STF32 -> "f32"
STF64 -> "f64"
+ppSMTy' d (SMTUser t) = ppSTy' d (STUser t)
ppString :: String -> Doc x
ppString = fromString
diff --git a/src/CHAD/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs
index 85f2882..30e6b6f 100644
--- a/src/CHAD/AST/Sparse.hs
+++ b/src/CHAD/AST/Sparse.hs
@@ -1,7 +1,9 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where
@@ -9,8 +11,10 @@ module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) wh
import Data.Type.Equality
import CHAD.AST
+import CHAD.AST.Count
+import CHAD.AST.Env
import CHAD.AST.Sparse.Types
-import CHAD.Data (SBool(..))
+import CHAD.Data
sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
@@ -43,6 +47,7 @@ sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 =
(EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ))))
sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2
sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
+sparsePlus (SMTUser t) SpUser e1 e2 = EPlus ext (SMTUser t) e1 e2
cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
@@ -61,6 +66,19 @@ cheapZero (SMTScal t) = case t of
STI64 -> Just (EConst ext t 0)
STF32 -> Just (EConst ext t 0.0)
STF64 -> Just (EConst ext t 0.0)
+cheapZero (SMTUser t) =
+ let zero1 = euserZero t (EVar ext (userZeroInfo t) IZ)
+ occenv1 = occCountAll @_ @'[_] zero1
+ zero2 = euserZero t (euserZeroInfo t (EVar ext (userRepTy t) IZ))
+ occenv2 = occCountAll @_ @'[_] zero2
+ in deleteUnused (userZeroInfo t `SCons` SNil) occenv1 $ \case
+ sub@(SENo SETop) | cheapExpr zero1 ->
+ Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero1)))
+ _ ->
+ deleteUnused (userRepTy t `SCons` SNil) occenv2 $ \case
+ sub@(SENo SETop) | cheapExpr zero2 ->
+ Just (EUser ext (STUser t) (weakenExpr WClosed (unsafeWeakenWithSubenv sub zero2)))
+ _ -> Nothing
data Injection sp a b where
@@ -294,3 +312,6 @@ sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
-- scalars
sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))
+
+-- user types
+sparsePlusS _ _ (SMTUser t) SpUser SpUser k = k SpUser (Inj id) (Inj id) (EPlus ext (SMTUser t))
diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs
index 8f41ba4..f97a261 100644
--- a/src/CHAD/AST/Sparse/Types.hs
+++ b/src/CHAD/AST/Sparse/Types.hs
@@ -20,6 +20,7 @@ data Sparse t t' where
SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
SpScal :: Sparse (TScal t) (TScal t)
+ SpUser :: Sparse (TUser t) (TUser t)
deriving instance Show (Sparse t t')
class ApplySparse f where
@@ -33,6 +34,7 @@ instance ApplySparse STy where
applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
applySparse SpScal t = t
+ applySparse SpUser t = t
instance ApplySparse SMTy where
applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
@@ -42,6 +44,7 @@ instance ApplySparse SMTy where
applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
applySparse SpScal t = t
+ applySparse SpUser t = t
class IsSubType s where
@@ -68,6 +71,7 @@ instance IsSubType Sparse where
subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
subtTrans SpScal SpScal = SpScal
+ subtTrans SpUser SpUser = SpUser
subtFull = spDense
@@ -78,6 +82,7 @@ spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
spDense (SMTMaybe t) = SpMaybe (spDense t)
spDense (SMTArr _ t) = SpArr (spDense t)
spDense (SMTScal _) = SpScal
+spDense (SMTUser _) = SpUser
isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
isDense SMTNil SpAbsent = Just Refl
@@ -96,6 +101,7 @@ isDense (SMTArr _ t) (SpArr s)
| Just Refl <- isDense t s = Just Refl
| otherwise = Nothing
isDense (SMTScal _) SpScal = Just Refl
+isDense (SMTUser _) SpUser = Just Refl
isAbsent :: Sparse t t' -> Bool
isAbsent (SpSparse s) = isAbsent s
@@ -105,3 +111,4 @@ isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
isAbsent (SpMaybe s) = isAbsent s
isAbsent (SpArr s) = isAbsent s
isAbsent SpScal = False
+isAbsent SpUser = False
diff --git a/src/CHAD/AST/SplitLets.hs b/src/CHAD/AST/SplitLets.hs
index 34267e4..75c70ea 100644
--- a/src/CHAD/AST/SplitLets.hs
+++ b/src/CHAD/AST/SplitLets.hs
@@ -79,6 +79,8 @@ splitLets' = \sub -> \case
EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
EError x t s -> EError x t s
+ EUser x t e -> EUser x t (splitLets' sub e)
+ EUnUser x e -> EUnUser x (splitLets' sub e)
where
sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
-> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t
@@ -167,6 +169,7 @@ split typ = case typ of
STArr{} -> other
STScal{} -> other
STAccum{} -> other
+ STUser{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])
other = (Point typ IZ, BTop)
@@ -186,6 +189,7 @@ splitRec rhs typ = case typ of
STArr{} -> other
STScal{} -> other
STAccum{} -> other
+ STUser{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex env '[t])
other = (Point typ IZ, BPush BTop (typ, rhs))
diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs
index f0feb55..bb2fcfa 100644
--- a/src/CHAD/AST/Types.hs
+++ b/src/CHAD/AST/Types.hs
@@ -1,35 +1,34 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-module CHAD.AST.Types where
+{-# LANGUAGE UndecidableSuperClasses #-}
+module CHAD.AST.Types (
+ module CHAD.AST.Types.Ty,
+ module CHAD.AST.Types,
+) where
import Data.Int (Int32, Int64)
import Data.GADT.Compare
import Data.GADT.Show
import Data.Kind (Type)
+import Data.Proxy
import Data.Type.Equality
+import Type.Reflection
import CHAD.Data
+import CHAD.AST.Types.Ty
+import {-# SOURCE #-} CHAD.AST
-type data Ty
- = TNil
- | TPair Ty Ty
- | TEither Ty Ty
- | TLEither Ty Ty
- | TMaybe Ty
- | TArr Nat Ty -- ^ rank, element type
- | TScal ScalTy
- | TAccum Ty -- ^ contained type must be a monoid type
-
-type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
-- | Scalar types happen to be bundled in 'SScalTy' as this is sometimes
-- convenient, but such scalar types are not special in any way.
@@ -43,6 +42,10 @@ data STy t where
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
STAccum :: SMTy t -> STy (TAccum t)
+ -- the Proxy is here just to provide something that's t-indexed, which is
+ -- damn useful; the UNPACK is to make sure it doesn't actually result in any
+ -- storage here
+ STUser :: UserPrimalType t => {-# UNPACK #-} !(Proxy t) -> STy (TUser t)
deriving instance Show (STy t)
instance GCompare STy where
@@ -62,7 +65,9 @@ instance GCompare STy where
(STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
STScal{} _ -> GLT ; _ STScal{} -> GGT
(STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
- -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+ STAccum{} _ -> GLT ; _ STAccum{} -> GGT
+ (STUser t) (STUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t'))
+ -- STUser{} _ -> GLT ; _ STUser{} -> GGT
instance TestEquality STy where testEquality = geq
instance GEq STy where geq = defaultGeq
@@ -77,6 +82,7 @@ data SMTy t where
SMTMaybe :: SMTy a -> SMTy (TMaybe a)
SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t)
SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t)
+ SMTUser :: UserMonoidType t => {-# UNPACK #-} !(Proxy t) -> SMTy (TUser t)
deriving instance Show (SMTy t)
instance GCompare SMTy where
@@ -92,7 +98,9 @@ instance GCompare SMTy where
(SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT
(SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t')
- -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT
+ SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT
+ (SMTUser t) (SMTUser t') -> gorderingLift1 (gcompare (typeOfProxy t) (typeOfProxy t'))
+ -- SMTUser{} _ -> GLT ; _ SMTUser{} -> GGT
instance TestEquality SMTy where testEquality = geq
instance GEq SMTy where geq = defaultGeq
@@ -106,6 +114,7 @@ fromSMTy = \case
SMTMaybe t -> STMaybe (fromSMTy t)
SMTArr n t -> STArr n (fromSMTy t)
SMTScal sty -> STScal sty
+ SMTUser t -> STUser t
data SScalTy t where
STI32 :: SScalTy TI32
@@ -172,7 +181,65 @@ type family ScalIsIntegral t where
ScalIsIntegral TF64 = False
ScalIsIntegral TBool = False
--- | Returns true for arrays /and/ accumulators.
+type family ZeroInfo t where
+ ZeroInfo TNil = TNil
+ ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b)
+ ZeroInfo (TLEither a b) = TNil
+ ZeroInfo (TMaybe a) = TNil
+ ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
+ ZeroInfo (TScal t) = TNil
+ ZeroInfo (TUser t) = UserZeroInfo t
+
+tZeroInfo :: SMTy t -> STy (ZeroInfo t)
+tZeroInfo SMTNil = STNil
+tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b)
+tZeroInfo (SMTLEither _ _) = STNil
+tZeroInfo (SMTMaybe _) = STNil
+tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
+tZeroInfo (SMTScal _) = STNil
+tZeroInfo (SMTUser t) = userZeroInfo t
+
+-- | Info needed to create a zero-valued deep accumulator for a monoid type.
+-- Constructable from a D1; must not have any dynamic sparsity, i.e. must have
+-- the same structure as the corresponding primal.
+type family DeepZeroInfo t where
+ DeepZeroInfo TNil = TNil
+ DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
+ DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
+ DeepZeroInfo (TScal t) = TNil
+ DeepZeroInfo (TUser t) = UserDeepZeroInfo t
+
+tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
+tDeepZeroInfo SMTNil = STNil
+tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
+tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
+tDeepZeroInfo (SMTScal _) = STNil
+tDeepZeroInfo (SMTUser t) = userDeepZeroInfo t
+
+class (Show t, Typeable t, UserMonoidType (UserD2 t)) => UserPrimalType t where
+ type UserRep t :: Ty
+ type UserD2 t :: Type
+ userRepTy :: proxy t -> STy (UserRep t)
+ euserD2ZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (ZeroInfo (TUser (UserD2 t)))
+ euserD2DeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (DeepZeroInfo (TUser (UserD2 t)))
+
+class UserPrimalType t => UserMonoidType t where
+ type UserZeroInfo t :: Ty
+ type UserDeepZeroInfo t :: Ty
+ userZeroInfo :: proxy t -> STy (UserZeroInfo t)
+ userDeepZeroInfo :: proxy t -> STy (UserDeepZeroInfo t)
+ euserZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserZeroInfo t)
+ euserDeepZeroInfo :: proxy t -> Ex env (UserRep t) -> Ex env (UserDeepZeroInfo t)
+ euserZero :: proxy t -> Ex env (UserZeroInfo t) -> Ex env (UserRep t)
+ -- | A deep zero must not have any dynamic sparsity.
+ euserDeepZero :: proxy t -> Ex env (UserDeepZeroInfo t) -> Ex env (UserRep t)
+ euserPlus :: proxy t -> Ex env (UserRep t) -> Ex env (UserRep t) -> Ex env (UserRep t)
+
+-- | Returns true for arrays /and/ accumulators. Returns False for user types.
typeHasArrays :: STy t' -> Bool
typeHasArrays STNil = False
typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b
@@ -182,6 +249,7 @@ typeHasArrays (STMaybe t) = typeHasArrays t
typeHasArrays STArr{} = True
typeHasArrays STScal{} = False
typeHasArrays STAccum{} = True
+typeHasArrays STUser{} = False
typeHasAccums :: STy t' -> Bool
typeHasAccums STNil = False
@@ -192,6 +260,7 @@ typeHasAccums (STMaybe t) = typeHasAccums t
typeHasAccums STArr{} = False
typeHasAccums STScal{} = False
typeHasAccums STAccum{} = True
+typeHasAccums STUser{} = False
type family Tup env where
Tup '[] = TNil
@@ -215,3 +284,6 @@ unTup unpack (_ `SCons` list) tup =
type family InvTup core env where
InvTup core '[] = core
InvTup core (t : ts) = InvTup (TPair core t) ts
+
+typeOfProxy :: Typeable a => proxy a -> TypeRep a
+typeOfProxy _ = typeRep
diff --git a/src/CHAD/AST/Types/Ty.hs b/src/CHAD/AST/Types/Ty.hs
new file mode 100644
index 0000000..cee03be
--- /dev/null
+++ b/src/CHAD/AST/Types/Ty.hs
@@ -0,0 +1,20 @@
+{-# LANGUAGE TypeData #-}
+module CHAD.AST.Types.Ty where
+
+import Data.Kind (Type)
+
+import CHAD.Data (Nat)
+
+
+type data Ty
+ = TNil
+ | TPair Ty Ty
+ | TEither Ty Ty
+ | TLEither Ty Ty
+ | TMaybe Ty
+ | TArr Nat Ty -- ^ rank, element type
+ | TScal ScalTy
+ | TAccum Ty -- ^ contained type must be a monoid type
+ | TUser Type
+
+type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs
index d3cad25..2166fc6 100644
--- a/src/CHAD/AST/UnMonoid.hs
+++ b/src/CHAD/AST/UnMonoid.hs
@@ -63,6 +63,8 @@ unMonoid = \case
acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
EError _ t s -> EError ext t s
+ EUser _ t e -> EUser ext t (unMonoid e)
+ EUnUser _ e -> EUnUser ext (unMonoid e)
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
-- don't destroy the effects!
@@ -78,6 +80,7 @@ zero (SMTScal t) _ = case t of
STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+zero (SMTUser t) e = EUser ext (STUser t) (euserZero t e)
deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
deepZero SMTNil e = elet e $ ENil ext
@@ -99,6 +102,7 @@ deepZero (SMTScal t) _ = case t of
STI64 -> EConst ext STI64 0
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+deepZero (SMTUser t) e = EUser ext (STUser t) (euserDeepZero t e)
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
-- don't destroy the effects!
@@ -136,6 +140,7 @@ plus (SMTArr _ t) a b =
ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
+plus (SMTUser t) a b = EUser ext (STUser t) (euserPlus t (EUnUser ext a) (EUnUser ext b))
onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
@@ -183,8 +188,8 @@ accumulateSparse
accumulateSparse topty topsp arg accum = case (topty, topsp) of
(_, s) | Just Refl <- isDense topty s ->
accum WId SAPHere (ENil ext) arg
- (SMTScal _, SpScal) ->
- accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh
+ (SMTScal _, SpScal) -> error "TScal is dense"
+ (SMTUser _, SpUser) -> error "TUser is dense"
(_, SpSparse s) ->
emaybe arg
(ENil ext)
diff --git a/src/CHAD/AST/UnUser.hs b/src/CHAD/AST/UnUser.hs
new file mode 100644
index 0000000..73b216c
--- /dev/null
+++ b/src/CHAD/AST/UnUser.hs
@@ -0,0 +1,102 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module CHAD.AST.UnUser where
+
+import CHAD.AST
+
+
+type family UnUser t where
+ UnUser TNil = TNil
+ UnUser (TPair a b) = TPair (UnUser a) (UnUser b)
+ UnUser (TEither a b) = TEither (UnUser a) (UnUser b)
+ UnUser (TLEither a b) = TLEither (UnUser a) (UnUser b)
+ UnUser (TMaybe a) = TMaybe (UnUser a)
+ UnUser (TArr n t) = TArr n (UnUser t)
+ UnUser (TScal t) = TScal t
+ UnUser (TAccum t) = TAccum (UnUser t)
+ UnUser (TUser t) = UnUser (UserRep t)
+
+type family UnUserE env where
+ UnUserE '[] = '[]
+ UnUserE (t : ts) = UnUser t : UnUserE ts
+
+unUserTy :: STy t -> STy (UnUser t)
+unUserTy = \case
+ STNil -> STNil
+ STPair a b -> STPair (unUserTy a) (unUserTy b)
+ STEither a b -> STEither (unUserTy a) (unUserTy b)
+ STLEither a b -> STLEither (unUserTy a) (unUserTy b)
+ STMaybe t -> STMaybe (unUserTy t)
+ STArr n t -> STArr n (unUserTy t)
+ STScal t -> STScal t
+ STAccum t -> STAccum (unUserMTy t)
+ STUser t -> unUserTy (userRepTy t)
+
+unUserMTy :: SMTy t -> SMTy (UnUser t)
+unUserMTy = \case
+ SMTNil -> SMTNil
+ SMTPair a b -> SMTPair (unUserMTy a) (unUserMTy b)
+ SMTLEither a b -> SMTLEither (unUserMTy a) (unUserMTy b)
+ SMTMaybe t -> SMTMaybe (unUserMTy t)
+ SMTArr n t -> SMTArr n (unUserMTy t)
+ SMTScal t -> SMTScal t
+ SMTUser t -> unUserMTy (userRepTy t)
+
+unUser :: Ex env t -> Ex (UnUserE env) (UnUser t)
+unUser = \case
+ EUser _ _ e -> unUser e
+ EUnUser _ e -> unUser e
+
+ EVar _ t i -> EVar ext t (goIdx i)
+ ELet _ rhs body -> ELet ext (unUser rhs) (unUser body)
+ EPair _ a b -> EPair ext (unUser a) (unUser b)
+ EFst _ e -> EFst ext (unUser e)
+ ESnd _ e -> ESnd ext (unUser e)
+ ENil _ -> ENil ext
+ EInl _ t e -> EInl ext (unUserTy t) (unUser e)
+ EInr _ t e -> EInr ext (unUserTy t) (unUser e)
+ ECase _ e a b -> ECase ext (unUser e) (unUser a) (unUser b)
+ ENothing _ t -> ENothing ext t
+ EJust _ e -> EJust ext (unUser e)
+ EMaybe _ a b e -> EMaybe ext (unUser a) (unUser b) (unUser e)
+ ELNil _ t1 t2 -> ELNil ext t1 t2
+ ELInl _ t e -> ELInl ext t (unUser e)
+ ELInr _ t e -> ELInr ext t (unUser e)
+ ELCase _ e a b c -> ELCase ext (unUser e) (unUser a) (unUser b) (unUser c)
+ EConstArr _ n t x -> EConstArr ext n t x
+ EBuild _ n a b -> EBuild ext n (unUser a) (unUser b)
+ EMap _ a b -> EMap ext (unUser a) (unUser b)
+ EFold1Inner _ cm a b c -> EFold1Inner ext cm (unUser a) (unUser b) (unUser c)
+ ESum1Inner _ e -> ESum1Inner ext (unUser e)
+ EUnit _ e -> EUnit ext (unUser e)
+ EReplicate1Inner _ a b -> EReplicate1Inner ext (unUser a) (unUser b)
+ EMaximum1Inner _ e -> EMaximum1Inner ext (unUser e)
+ EMinimum1Inner _ e -> EMinimum1Inner ext (unUser e)
+ EReshape _ n a b -> EReshape ext n (unUser a) (unUser b)
+ EZip _ a b -> EZip ext (unUser a) (unUser b)
+ EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unUser a) (unUser b) (unUser c)
+ EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unUser a) (unUser b) (unUser c)
+ EConst _ t x -> EConst ext t x
+ EIdx0 _ e -> EIdx0 ext (unUser e)
+ EIdx1 _ a b -> EIdx1 ext (unUser a) (unUser b)
+ EIdx _ a b -> EIdx ext (unUser a) (unUser b)
+ EShape _ e -> EShape ext (unUser e)
+ EOp _ op e -> EOp ext op (unUser e)
+ ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unUser a) (unUser b) (unUser c) (unUser e1) (unUser e2)
+ ERecompute _ e -> ERecompute ext (unUser e)
+ EWith _ t a b -> EWith ext t (unUser a) (unUser b)
+ EAccum _ t p eidx sp eval eacc ->
+ accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
+ EAccum ext t prj' (unUser idx') (spDense (acPrjTy prj' t)) (unUser val2) (weakenExpr w (unUser eacc))
+ EError _ t s -> EError ext t s
+
+ EZero{} -> err_monoid
+ EDeepZero{} -> err_monoid
+ EPlus{} -> err_monoid
+ EOneHot{} -> err_monoid
+ where
+ err_monoid = error "unUser: Monoid ops found"