{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
module AST.UnMonoid (unMonoid, zero, plus) where

import AST
import Data


-- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them
-- into their concrete implementations.
--
-- Every other constructor is rebuilt with a fresh 'ext' annotation and with
-- 'unMonoid' applied to each of its subexpressions, so monoid operations are
-- eliminated at any depth, including inside the operands of the monoid
-- operations themselves.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
  -- Recurse into the zero-info argument as well: without this, any monoid
  -- operation nested inside it would survive the pass.
  EZero _ t e -> zero t (unMonoid e)
  EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
  EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)

  EVar _ t i -> EVar ext t i
  ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
  EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
  EFst _ e -> EFst ext (unMonoid e)
  ESnd _ e -> ESnd ext (unMonoid e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext t (unMonoid e)
  EInr _ t e -> EInr ext t (unMonoid e)
  ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
  ENothing _ t -> ENothing ext t
  EJust _ e -> EJust ext (unMonoid e)
  EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
  ELNil _ t1 t2 -> ELNil ext t1 t2
  ELInl _ t e -> ELInl ext t (unMonoid e)
  ELInr _ t e -> ELInr ext t (unMonoid e)
  ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
  EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
  ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
  EUnit _ e -> EUnit ext (unMonoid e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unMonoid e)
  EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
  EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
  EShape _ e -> EShape ext (unMonoid e)
  EOp _ op e -> EOp ext op (unMonoid e)
  ECustom _ t1 t2 t3 a b c e1 e2 ->
    ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
  EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
  EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
  EError _ t s -> EError ext t s

-- | Build an expression for the zero value of the monoid type described by
-- the 'SMTy'.
--
-- The 'ZeroInfo' argument is only consulted for pairs (it is let-bound once
-- and split into per-component info with 'EFst' / 'ESnd') and for arrays
-- (the element-level zero is mapped over it with 'emap'). For units, lifted
-- eithers, maybes and scalars it is ignored.
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
zero SMTNil _ = ENil ext
zero (SMTPair t1 t2) e =
  -- Let-bind the info so it is not duplicated into both components.
  ELet ext e $
    EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
              (zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2)
zero (SMTMaybe t) _ = ENothing ext (fromSMTy t)
zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e
zero (SMTScal t) _ = case t of
  STI32 -> EConst ext STI32 0
  STI64 -> EConst ext STI64 0
  STF32 -> EConst ext STF32 0.0
  STF64 -> EConst ext STF64 0.0

-- | Build an expression that adds two values of the given monoid type.
--
-- * Unit: always 'ENil'.
-- * Pairs: component-wise addition.
-- * Lifted eithers: 'ELNil' is the identity. Two 'ELInl's (or two 'ELInr's)
--   are added under the constructor. A mismatched left/right pair evaluates
--   to an 'EError' ("plus l+r" / "plus r+l").
-- * Maybes: 'ENothing' is the identity. Two 'EJust's are added under 'EJust'.
-- * Arrays: element-wise addition via 'ezipWith'.
-- * Scalars: the 'OAdd' primitive.
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
plus SMTNil _ _ = ENil ext
plus (SMTPair t1 t2) a b =
  let t = STPair (fromSMTy t1) (fromSMTy t2)
  in ELet ext a $
     ELet ext (weakenExpr WSink b) $
       -- Environment here: b : a : env, so IZ is b and IS IZ is a.
       EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) (EFst ext (EVar ext t IZ)))
                 (plus t2 (ESnd ext (EVar ext t (IS IZ))) (ESnd ext (EVar ext t IZ)))
plus (SMTLEither t1 t2) a b =
  let t = STLEither (fromSMTy t1) (fromSMTy t2)
  in ELet ext a $
     ELet ext (weakenExpr WSink b) $
       -- Environment here: b : a : env. Scrutinise a first.
       ELCase ext (EVar ext t (IS IZ))
         -- a is nil: the result is b.
         (EVar ext t IZ)
         -- a is left x (env: x : b : a : env). Now scrutinise b.
         (ELCase ext (EVar ext t (IS IZ))
            -- b is nil: the result is a.
            (EVar ext t (IS (IS IZ)))
            -- b is left y (env: y : x : b : a : env): add x and y.
            (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ)))
            (EError ext t "plus l+r"))
         -- a is right x (env: x : b : a : env). Now scrutinise b.
         (ELCase ext (EVar ext t (IS IZ))
            (EVar ext t (IS (IS IZ)))
            (EError ext t "plus r+l")
            (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ))))
plus (SMTMaybe t) a b =
  ELet ext b $
    -- Environment here: b : env. Scrutinise (weakened) a.
    EMaybe ext
      -- a is Nothing: the result is b.
      (EVar ext (STMaybe (fromSMTy t)) IZ)
      -- a is Just x (env: x : b : env). Now scrutinise b.
      (EJust ext
         (EMaybe ext
            -- b is Nothing: the result is x.
            (EVar ext (fromSMTy t) IZ)
            -- b is Just y (env: y : x : b : env): add x and y.
            (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ))
            (EVar ext (STMaybe (fromSMTy t)) (IS IZ))))
      (weakenExpr WSink a)
plus (SMTArr _ t) a b =
  ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)

-- | Build a value of monoid type @t@ that holds @arg@ at the position selected
-- by the projection and is zero elsewhere.
--
-- The index expression carries both the path information needed by the
-- projection and the zero-info (see 'zero') for the parts that are not
-- selected:
--
-- * pairs: one half indexes into the selected component, the other half is
--   the zero-info of the sibling component;
-- * arrays: @((position, zero-info array), inner index)@. The shape of the
--   zero-info array determines the shape of the result.
onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
  (_, SAPHere) ->
    ELet ext arg $
      EVar ext (fromSMTy typ) IZ

  (SMTPair t1 t2, SAPFst prj) ->
    ELet ext idx $
      let tidx = typeOf idx in
      ELet ext (onehot t1 prj (EFst ext (EVar ext tidx IZ)) (weakenExpr WSink arg)) $
        let toh = fromSMTy t1 in
        -- Environment here: onehot-result : idx : env.
        EPair ext (EVar ext toh IZ) (zero t2 (ESnd ext (EVar ext tidx (IS IZ))))

  (SMTPair t1 t2, SAPSnd prj) ->
    ELet ext idx $
      let tidx = typeOf idx in
      ELet ext (onehot t2 prj (ESnd ext (EVar ext tidx IZ)) (weakenExpr WSink arg)) $
        let toh = fromSMTy t2 in
        -- Environment here: onehot-result : idx : env.
        EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) (EVar ext toh IZ)

  (SMTLEither t1 t2, SAPLeft prj) ->
    ELInl ext (fromSMTy t2) (onehot t1 prj idx arg)

  (SMTLEither t1 t2, SAPRight prj) ->
    ELInr ext (fromSMTy t1) (onehot t2 prj idx arg)

  (SMTMaybe t1, SAPJust prj) ->
    EJust ext (onehot t1 prj idx arg)

  (SMTArr n t1, SAPArrIdx prj) ->
    let tix = tTup (sreplicate n tIx)  -- type of an n-dimensional array index
        tidx = typeOf idx
    in ELet ext idx $
         EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext tidx IZ)))) $
           -- Environment here: i : idx : env, where i is the current element index.
           eif (eidxEq n (EVar ext tix IZ) (EFst ext (EFst ext (EVar ext tidx (IS IZ)))))
               -- At the target position: recurse into the element.
               (onehot t1 prj (ESnd ext (EVar ext tidx (IS IZ))) (weakenExpr (WSink .> WSink) arg))
               -- Elsewhere: the zero built from that element's zero-info.
               (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext tidx (IS IZ)))) (EVar ext tix IZ)) $
                  zero t1 (EVar ext (tZeroInfo t1) IZ))