{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Elimination of user-defined types: every @TUser t@ is replaced by the
-- translation of its representation type @UserRep t@, at the type level
-- ('UnUser', 'UnUserE'), on type singletons ('unUserTy', 'unUserMTy') and on
-- expressions ('unUser').
module CHAD.AST.UnUser where

import CHAD.AST


-- | Type-level translation that removes 'TUser' from a type. All other type
-- constructors are traversed structurally; scalar types are left unchanged.
type family UnUser t where
  UnUser TNil = TNil
  UnUser (TPair a b) = TPair (UnUser a) (UnUser b)
  UnUser (TEither a b) = TEither (UnUser a) (UnUser b)
  UnUser (TLEither a b) = TLEither (UnUser a) (UnUser b)
  UnUser (TMaybe a) = TMaybe (UnUser a)
  UnUser (TArr n t) = TArr n (UnUser t)
  UnUser (TScal t) = TScal t
  UnUser (TAccum t) = TAccum (UnUser t)
  -- The representation type is translated again, so a representation that
  -- itself mentions 'TUser' is unfolded too (hence UndecidableInstances).
  UnUser (TUser t) = UnUser (UserRep t)

-- | 'UnUser' applied to every entry of a type-level environment.
type family UnUserE env where
  UnUserE '[] = '[]
  UnUserE (t : ts) = UnUser t : UnUserE ts

-- | Singleton-level counterpart of 'UnUser': turns the singleton for @t@ into
-- the singleton for @UnUser t@.
unUserTy :: STy t -> STy (UnUser t)
unUserTy = \case
  STNil -> STNil
  STPair a b -> STPair (unUserTy a) (unUserTy b)
  STEither a b -> STEither (unUserTy a) (unUserTy b)
  STLEither a b -> STLEither (unUserTy a) (unUserTy b)
  STMaybe t -> STMaybe (unUserTy t)
  STArr n t -> STArr n (unUserTy t)
  STScal t -> STScal t
  -- The payload of an accumulator type is a monoid-type singleton.
  STAccum t -> STAccum (unUserMTy t)
  STUser t -> unUserTy (userRepTy t)

-- | Same as 'unUserTy', for monoid-type singletons.
unUserMTy :: SMTy t -> SMTy (UnUser t)
unUserMTy = \case
  SMTNil -> SMTNil
  SMTPair a b -> SMTPair (unUserMTy a) (unUserMTy b)
  SMTLEither a b -> SMTLEither (unUserMTy a) (unUserMTy b)
  SMTMaybe t -> SMTMaybe (unUserMTy t)
  SMTArr n t -> SMTArr n (unUserMTy t)
  SMTScal t -> SMTScal t
  -- NOTE(review): 'userRepTy' is also used in 'unUserTy', where its result is
  -- consumed as an 'STy'. It is not defined in the visible source, so whether
  -- it yields an 'SMTy' for this payload cannot be checked here -- confirm.
  SMTUser t -> unUserMTy (userRepTy t)

-- | Translate an expression so that neither its environment nor its result
-- type mentions 'TUser' any more.
--
-- 'EUser' and 'EUnUser' nodes are dropped: only their sub-expression is
-- translated. All other supported constructors are rebuilt around their
-- translated sub-expressions, with type singletons passed through 'unUserTy'
-- or 'unUserMTy' so that they describe the translated types.
--
-- Calls 'error' when the expression contains 'EZero', 'EDeepZero', 'EPlus' or
-- 'EOneHot' (presumably these are expected to have been eliminated by an
-- earlier pass -- TODO confirm).
unUser :: Ex env t -> Ex (UnUserE env) (UnUser t)
unUser = \case
  EUser _ _ e -> unUser e
  EUnUser _ e -> unUser e

  -- NOTE(review): 'goIdx' is not defined in the visible part of this file;
  -- presumably it maps an index into @env@ to one into @UnUserE env@.
  EVar _ t i -> EVar ext (unUserTy t) (goIdx i)
  ELet _ rhs body -> ELet ext (unUser rhs) (unUser body)
  EPair _ a b -> EPair ext (unUser a) (unUser b)
  EFst _ e -> EFst ext (unUser e)
  ESnd _ e -> ESnd ext (unUser e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext (unUserTy t) (unUser e)
  EInr _ t e -> EInr ext (unUserTy t) (unUser e)
  ECase _ e a b -> ECase ext (unUser e) (unUser a) (unUser b)
  -- The singletons below describe (parts of) the result type, so they must be
  -- translated just like the ones of EInl/EInr above.
  ENothing _ t -> ENothing ext (unUserTy t)
  EJust _ e -> EJust ext (unUser e)
  EMaybe _ a b e -> EMaybe ext (unUser a) (unUser b) (unUser e)
  ELNil _ t1 t2 -> ELNil ext (unUserTy t1) (unUserTy t2)
  ELInl _ t e -> ELInl ext (unUserTy t) (unUser e)
  ELInr _ t e -> ELInr ext (unUserTy t) (unUser e)
  ELCase _ e a b c -> ELCase ext (unUser e) (unUser a) (unUser b) (unUser c)

  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unUser a) (unUser b)
  EMap _ a b -> EMap ext (unUser a) (unUser b)
  EFold1Inner _ cm a b c -> EFold1Inner ext cm (unUser a) (unUser b) (unUser c)
  ESum1Inner _ e -> ESum1Inner ext (unUser e)
  EUnit _ e -> EUnit ext (unUser e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unUser a) (unUser b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unUser e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unUser e)
  EReshape _ n a b -> EReshape ext n (unUser a) (unUser b)
  EZip _ a b -> EZip ext (unUser a) (unUser b)

  EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unUser a) (unUser b) (unUser c)
  EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unUser a) (unUser b) (unUser c)

  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unUser e)
  EIdx1 _ a b -> EIdx1 ext (unUser a) (unUser b)
  EIdx _ a b -> EIdx ext (unUser a) (unUser b)
  EShape _ e -> EShape ext (unUser e)
  -- NOTE(review): EOp and ECustom forward their operator / type singletons
  -- unchanged; whether those need translating depends on constructor
  -- signatures that are not visible here.
  EOp _ op e -> EOp ext op (unUser e)

  ECustom _ t1 t2 t3 a b c e1 e2 ->
    ECustom ext t1 t2 t3 (unUser a) (unUser b) (unUser c) (unUser e1) (unUser e2)
  ERecompute _ e -> ERecompute ext (unUser e)

  -- The accumulator's monoid type must be the translated one, matching the
  -- translated initial value.
  EWith _ t a b -> EWith ext (unUserMTy t) (unUser a) (unUser b)
  -- NOTE(review): 'accumulateSparse' and 'acPrjCompose' are not defined in the
  -- visible part of this file; this case is kept exactly as it was.
  EAccum _ t p eidx sp eval eacc ->
    accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
    acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
      EAccum ext t prj' (unUser idx') (spDense (acPrjTy prj' t)) (unUser val2) (weakenExpr w (unUser eacc))
  EError _ t s -> EError ext (unUserTy t) s

  EZero{} -> err_monoid
  EDeepZero{} -> err_monoid
  EPlus{} -> err_monoid
  EOneHot{} -> err_monoid
  where
    -- Shared failure for the monoid constructors, which this pass does not
    -- handle.
    err_monoid = error "unUser: Monoid ops found"