diff options
Diffstat (limited to 'src/CHAD/AST/UnUser.hs')
| -rw-r--r-- | src/CHAD/AST/UnUser.hs | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/src/CHAD/AST/UnUser.hs b/src/CHAD/AST/UnUser.hs new file mode 100644 index 0000000..73b216c --- /dev/null +++ b/src/CHAD/AST/UnUser.hs @@ -0,0 +1,102 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.AST.UnUser where + +import CHAD.AST + + +type family UnUser t where + UnUser TNil = TNil + UnUser (TPair a b) = TPair (UnUser a) (UnUser b) + UnUser (TEither a b) = TEither (UnUser a) (UnUser b) + UnUser (TLEither a b) = TLEither (UnUser a) (UnUser b) + UnUser (TMaybe a) = TMaybe (UnUser a) + UnUser (TArr n t) = TArr n (UnUser t) + UnUser (TScal t) = TScal t + UnUser (TAccum t) = TAccum (UnUser t) + UnUser (TUser t) = UnUser (UserRep t) + +type family UnUserE env where + UnUserE '[] = '[] + UnUserE (t : ts) = UnUser t : UnUserE ts + +unUserTy :: STy t -> STy (UnUser t) +unUserTy = \case + STNil -> STNil + STPair a b -> STPair (unUserTy a) (unUserTy b) + STEither a b -> STEither (unUserTy a) (unUserTy b) + STLEither a b -> STLEither (unUserTy a) (unUserTy b) + STMaybe t -> STMaybe (unUserTy t) + STArr n t -> STArr n (unUserTy t) + STScal t -> STScal t + STAccum t -> STAccum (unUserMTy t) + STUser t -> unUserTy (userRepTy t) + +unUserMTy :: SMTy t -> SMTy (UnUser t) +unUserMTy = \case + SMTNil -> SMTNil + SMTPair a b -> SMTPair (unUserMTy a) (unUserMTy b) + SMTLEither a b -> SMTLEither (unUserMTy a) (unUserMTy b) + SMTMaybe t -> SMTMaybe (unUserMTy t) + SMTArr n t -> SMTArr n (unUserMTy t) + SMTScal t -> SMTScal t + SMTUser t -> unUserMTy (userRepTy t) + +unUser :: Ex env t -> Ex (UnUserE env) (UnUser t) +unUser = \case + EUser _ _ e -> unUser e + EUnUser _ e -> unUser e + + EVar _ t i -> EVar ext t (goIdx i) + ELet _ rhs body -> ELet ext (unUser rhs) (unUser body) + EPair _ a b -> EPair ext (unUser a) (unUser b) + EFst _ e -> EFst ext (unUser e) + ESnd _ e -> ESnd ext (unUser e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext (unUserTy t) (unUser e) + EInr _ t e -> EInr ext (unUserTy t) (unUser e) + ECase _ e a b -> ECase ext (unUser e) (unUser a) (unUser b) + ENothing _ t -> ENothing ext t + EJust _ e -> EJust ext (unUser e) + EMaybe _ a b e -> EMaybe ext (unUser a) (unUser b) (unUser e) + ELNil _ t1 t2 -> ELNil ext t1 t2 + ELInl _ t e -> ELInl ext t (unUser e) + ELInr _ t e -> ELInr ext t (unUser e) + ELCase _ e a b c -> ELCase ext (unUser e) (unUser a) (unUser b) (unUser c) + EConstArr _ n t x -> EConstArr ext n t x + EBuild _ n a b -> EBuild ext n (unUser a) (unUser b) + EMap _ a b -> EMap ext (unUser a) (unUser b) + EFold1Inner _ cm a b c -> EFold1Inner ext cm (unUser a) (unUser b) (unUser c) + ESum1Inner _ e -> ESum1Inner ext (unUser e) + EUnit _ e -> EUnit ext (unUser e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (unUser a) (unUser b) + EMaximum1Inner _ e -> EMaximum1Inner ext (unUser e) + EMinimum1Inner _ e -> EMinimum1Inner ext (unUser e) + EReshape _ n a b -> EReshape ext n (unUser a) (unUser b) + EZip _ a b -> EZip ext (unUser a) (unUser b) + EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unUser a) (unUser b) (unUser c) + EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unUser a) (unUser b) (unUser c) + EConst _ t x -> EConst ext t x + EIdx0 _ e -> EIdx0 ext (unUser e) + EIdx1 _ a b -> EIdx1 ext (unUser a) (unUser b) + EIdx _ a b -> EIdx ext (unUser a) (unUser b) + EShape _ e -> EShape ext (unUser e) + EOp _ op e -> EOp ext op (unUser e) + ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unUser a) (unUser b) (unUser c) (unUser e1) (unUser e2) + ERecompute _ e -> ERecompute ext (unUser e) + EWith _ t a b -> EWith ext t (unUser a) (unUser b) + EAccum _ t p eidx sp eval eacc -> + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unUser idx') (spDense (acPrjTy prj' t)) (unUser val2) (weakenExpr w (unUser eacc)) + EError _ t s -> EError ext t s + + EZero{} -> err_monoid + EDeepZero{} -> err_monoid + EPlus{} -> err_monoid + EOneHot{} -> err_monoid + where + err_monoid = error "unUser: Monoid ops found" |
