aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/AST/UnUser.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD/AST/UnUser.hs')
-rw-r--r--src/CHAD/AST/UnUser.hs102
1 files changed, 102 insertions, 0 deletions
diff --git a/src/CHAD/AST/UnUser.hs b/src/CHAD/AST/UnUser.hs
new file mode 100644
index 0000000..73b216c
--- /dev/null
+++ b/src/CHAD/AST/UnUser.hs
@@ -0,0 +1,102 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+module CHAD.AST.UnUser where
+
+import CHAD.AST
+
+
+type family UnUser t where
+ UnUser TNil = TNil
+ UnUser (TPair a b) = TPair (UnUser a) (UnUser b)
+ UnUser (TEither a b) = TEither (UnUser a) (UnUser b)
+ UnUser (TLEither a b) = TLEither (UnUser a) (UnUser b)
+ UnUser (TMaybe a) = TMaybe (UnUser a)
+ UnUser (TArr n t) = TArr n (UnUser t)
+ UnUser (TScal t) = TScal t
+ UnUser (TAccum t) = TAccum (UnUser t)
+ UnUser (TUser t) = UnUser (UserRep t)
+
+type family UnUserE env where
+ UnUserE '[] = '[]
+ UnUserE (t : ts) = UnUser t : UnUserE ts
+
+unUserTy :: STy t -> STy (UnUser t)
+unUserTy = \case
+ STNil -> STNil
+ STPair a b -> STPair (unUserTy a) (unUserTy b)
+ STEither a b -> STEither (unUserTy a) (unUserTy b)
+ STLEither a b -> STLEither (unUserTy a) (unUserTy b)
+ STMaybe t -> STMaybe (unUserTy t)
+ STArr n t -> STArr n (unUserTy t)
+ STScal t -> STScal t
+ STAccum t -> STAccum (unUserMTy t)
+ STUser t -> unUserTy (userRepTy t)
+
+unUserMTy :: SMTy t -> SMTy (UnUser t)
+unUserMTy = \case
+ SMTNil -> SMTNil
+ SMTPair a b -> SMTPair (unUserMTy a) (unUserMTy b)
+ SMTLEither a b -> SMTLEither (unUserMTy a) (unUserMTy b)
+ SMTMaybe t -> SMTMaybe (unUserMTy t)
+ SMTArr n t -> SMTArr n (unUserMTy t)
+ SMTScal t -> SMTScal t
+ SMTUser t -> unUserMTy (userRepTy t)
+
+unUser :: Ex env t -> Ex (UnUserE env) (UnUser t)
+unUser = \case
+ EUser _ _ e -> unUser e
+ EUnUser _ e -> unUser e
+
+ EVar _ t i -> EVar ext t (goIdx i)
+ ELet _ rhs body -> ELet ext (unUser rhs) (unUser body)
+ EPair _ a b -> EPair ext (unUser a) (unUser b)
+ EFst _ e -> EFst ext (unUser e)
+ ESnd _ e -> ESnd ext (unUser e)
+ ENil _ -> ENil ext
+ EInl _ t e -> EInl ext (unUserTy t) (unUser e)
+ EInr _ t e -> EInr ext (unUserTy t) (unUser e)
+ ECase _ e a b -> ECase ext (unUser e) (unUser a) (unUser b)
+ ENothing _ t -> ENothing ext t
+ EJust _ e -> EJust ext (unUser e)
+ EMaybe _ a b e -> EMaybe ext (unUser a) (unUser b) (unUser e)
+ ELNil _ t1 t2 -> ELNil ext t1 t2
+ ELInl _ t e -> ELInl ext t (unUser e)
+ ELInr _ t e -> ELInr ext t (unUser e)
+ ELCase _ e a b c -> ELCase ext (unUser e) (unUser a) (unUser b) (unUser c)
+ EConstArr _ n t x -> EConstArr ext n t x
+ EBuild _ n a b -> EBuild ext n (unUser a) (unUser b)
+ EMap _ a b -> EMap ext (unUser a) (unUser b)
+ EFold1Inner _ cm a b c -> EFold1Inner ext cm (unUser a) (unUser b) (unUser c)
+ ESum1Inner _ e -> ESum1Inner ext (unUser e)
+ EUnit _ e -> EUnit ext (unUser e)
+ EReplicate1Inner _ a b -> EReplicate1Inner ext (unUser a) (unUser b)
+ EMaximum1Inner _ e -> EMaximum1Inner ext (unUser e)
+ EMinimum1Inner _ e -> EMinimum1Inner ext (unUser e)
+ EReshape _ n a b -> EReshape ext n (unUser a) (unUser b)
+ EZip _ a b -> EZip ext (unUser a) (unUser b)
+ EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unUser a) (unUser b) (unUser c)
+ EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unUser a) (unUser b) (unUser c)
+ EConst _ t x -> EConst ext t x
+ EIdx0 _ e -> EIdx0 ext (unUser e)
+ EIdx1 _ a b -> EIdx1 ext (unUser a) (unUser b)
+ EIdx _ a b -> EIdx ext (unUser a) (unUser b)
+ EShape _ e -> EShape ext (unUser e)
+ EOp _ op e -> EOp ext op (unUser e)
+ ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unUser a) (unUser b) (unUser c) (unUser e1) (unUser e2)
+ ERecompute _ e -> ERecompute ext (unUser e)
+ EWith _ t a b -> EWith ext t (unUser a) (unUser b)
+ EAccum _ t p eidx sp eval eacc ->
+ accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
+ EAccum ext t prj' (unUser idx') (spDense (acPrjTy prj' t)) (unUser val2) (weakenExpr w (unUser eacc))
+ EError _ t s -> EError ext t s
+
+ EZero{} -> err_monoid
+ EDeepZero{} -> err_monoid
+ EPlus{} -> err_monoid
+ EOneHot{} -> err_monoid
+ where
+ err_monoid = error "unUser: Monoid ops found"