{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
-- | Rewrites the monoid constructors of the expression language ('EZero',
-- 'EPlus' and 'EOneHot') into expressions assembled by 'zero', 'plus' and
-- the module-private @onehot@.
module AST.UnMonoid (unMonoid, zero, plus) where

import AST
import CHAD.Types
import Data


-- | Traverse an expression and replace every 'EZero', 'EPlus' and 'EOneHot'
-- node by the expression that 'zero', 'plus' and @onehot@ build for it.
--
-- Every other constructor is rebuilt unchanged, except that its annotation
-- field is replaced by 'ext' and its subexpressions are processed
-- recursively.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
  -- The monoid constructors: these are the nodes that get rewritten.
  EZero _ t -> zero t
  EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
  EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)

  -- All remaining constructors: structural recursion only.
  EVar _ t i -> EVar ext t i
  ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
  EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
  EFst _ e -> EFst ext (unMonoid e)
  ESnd _ e -> ESnd ext (unMonoid e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext t (unMonoid e)
  EInr _ t e -> EInr ext t (unMonoid e)
  ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
  ENothing _ t -> ENothing ext t
  EJust _ e -> EJust ext (unMonoid e)
  EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
  EFold1Inner _ a b c -> EFold1Inner ext (unMonoid a) (unMonoid b) (unMonoid c)
  ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
  EUnit _ e -> EUnit ext (unMonoid e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unMonoid e)
  EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
  EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
  EShape _ e -> EShape ext (unMonoid e)
  EOp _ op e -> EOp ext op (unMonoid e)
  ECustom _ t1 t2 t3 a b c e1 e2 ->
    ECustom ext t1 t2 t3
      (unMonoid a) (unMonoid b) (unMonoid c)
      (unMonoid e1) (unMonoid e2)
  EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
  EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
  EError _ t s -> EError ext t s

-- | Build the zero expression of type @D2 t@ for the given type singleton.
--
-- * 'STNil' yields 'ENil'.
-- * Pairs, eithers and maybes yield 'ENothing'.
-- * A rank-zero array ('SZ') is a unit array holding the zero of its
--   element type.
-- * Any other array is an 'EBuild' whose shape consists of @0@ constants
--   and whose element expression is @EError "empty"@.
-- * @F32@/@F64@ scalars yield the constant @0.0@; @I32@, @I64@ and @Bool@
--   yield 'ENil'.
--
-- Calls 'error' on 'STAccum'.
zero :: STy t -> Ex env (D2 t)
zero STNil = ENil ext
zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
zero (STMaybe t) = ENothing ext (d2 t)
-- The 'SZ' equation must precede the general array equation below.
zero (STArr SZ t) = EUnit ext (zero t)
zero (STArr n t) =
  EBuild ext n
    (eTup (sreplicate n (EConst ext STI64 0)))
    (EError ext (d2 t) "empty")
zero (STScal t) = case t of
  STI32 -> ENil ext
  STI64 -> ENil ext
  STF32 -> EConst ext STF32 0.0
  STF64 -> EConst ext STF64 0.0
  STBool -> ENil ext
zero STAccum{} = error "Accumulators not allowed in input program"

-- | Build the expression that adds two expressions of type @D2 t@.
--
-- * 'STNil' yields 'ENil', ignoring both arguments.
-- * Pairs, eithers and maybes are combined via @plusSparse@, which handles
--   the 'ENothing' cases; the adder passed to it combines the two payloads.
--   For eithers, payloads built with different constructors produce
--   @EError "plus l+r"@ or @EError "plus r+l"@.
-- * Arrays: both arguments are let-bound; if 'eshapeEmpty' holds for the
--   shape of the first, the result is the second; otherwise, if it holds
--   for the shape of the second, the result is the first; otherwise the two
--   arrays are combined with 'ezipWith' and a recursive 'plus' on elements.
-- * @F32@/@F64@ scalars use 'OAdd'; @I32@, @I64@ and @Bool@ yield 'ENil'.
--
-- Calls 'error' on 'STAccum'.
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
plus STNil _ _ = ENil ext
plus (STPair t1 t2) a b =
  let t = STPair (d2 t1) (d2 t2)
  in plusSparse t a b $
       -- Inside the adder, (IS IZ) is the payload of a and IZ that of b.
       EPair ext
         (plus t1 (EFst ext (EVar ext t (IS IZ))) (EFst ext (EVar ext t IZ)))
         (plus t2 (ESnd ext (EVar ext t (IS IZ))) (ESnd ext (EVar ext t IZ)))
plus (STEither t1 t2) a b =
  let t = STEither (d2 t1) (d2 t2)
  in plusSparse t a b $
       -- The outer case scrutinises the payload of a. Each outer branch
       -- binds one more variable, so the payload of b shifts from IZ to
       -- (IS IZ) in the inner cases.
       ECase ext (EVar ext t (IS IZ))
         (ECase ext (EVar ext t (IS IZ))
           (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
           (EError ext t "plus l+r"))
         (ECase ext (EVar ext t (IS IZ))
           (EError ext t "plus r+l")
           (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
plus (STMaybe t) a b =
  plusSparse (d2 t) a b $
    plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
plus (STArr n t) a b =
  let tarr = STArr n (d2 t)
  in -- After the two lets, (IS IZ) refers to a and IZ refers to b.
     ELet ext a $ ELet ext (weakenExpr WSink b) $
       eif (eshapeEmpty n (EShape ext (EVar ext tarr (IS IZ))))
           (EVar ext tarr IZ)
           (eif (eshapeEmpty n (EShape ext (EVar ext tarr IZ)))
                (EVar ext tarr (IS IZ))
                (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ))
                          (EVar ext tarr (IS IZ))
                          (EVar ext tarr IZ)))
plus (STScal t) a b = case t of
  STI32 -> ENil ext
  STI64 -> ENil ext
  STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
  STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
  STBool -> ENil ext
plus STAccum{} _ _ = error "Accumulators not allowed in input program"

-- | Combine two 'TMaybe'-typed expressions using an adder for the payloads.
--
-- The generated expression let-binds @b@ and then eliminates @a@:
--
-- * if @a@ is nothing, the result is @b@;
-- * if @a@ is just @x@ and @b@ is nothing, the result is @Just x@;
-- * if both are just, the result is @Just@ of the adder.
--
-- In the adder's environment, @IS IZ@ is the payload of @a@ and @IZ@ is the
-- payload of @b@.
plusSparse :: STy a
           -> Ex env (TMaybe a) -> Ex env (TMaybe a)
           -> Ex (a : a : env) a
           -> Ex env (TMaybe a)
plusSparse t a b adder =
  ELet ext b $
    EMaybe ext
      (EVar ext (STMaybe t) IZ)
      (EJust ext
        (EMaybe ext
          (EVar ext t IZ)
          -- Skip over the let-bound b, which sits between the two payload
          -- variables and the adder's original environment.
          (weakenExpr (WCopy (WCopy WSink)) adder)
          (EVar ext (STMaybe t) (IS IZ))))
      (weakenExpr WSink a)

-- | Build the expression of type @D2 t@ that holds @arg@ at the position
-- selected by the projection and index, with 'zero' in the sibling
-- positions.
--
-- * 'SAPHere' returns @arg@ itself.
-- * Pair projections wrap in 'EJust' a pair whose other component is 'zero'.
-- * Either and maybe projections wrap the recursive result in the matching
--   constructor under 'EJust'.
-- * Array projections let-bind @idx@ and build an array of shape
--   @snd (fst idx)@. The element whose build index equals @fst (fst idx)@
--   (per 'eidxEq') is the recursive one-hot using @snd idx@ as nested index;
--   every other element is 'zero'.
onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
onehot typ topprj idx arg = case (typ, topprj) of
  (_, SAPHere) -> arg

  (STPair t1 t2, SAPFst prj) ->
    EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2))
  (STPair t1 t2, SAPSnd prj) ->
    EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg))

  (STEither t1 t2, SAPLeft prj) ->
    EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg))
  (STEither t1 t2, SAPRight prj) ->
    EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg))

  (STMaybe t1, SAPJust prj) ->
    EJust ext (onehot t1 prj idx arg)

  (STArr n t1, SAPArrIdx prj _) ->
    let tidx = tTup (sreplicate n tIx)
    in ELet ext idx $
         -- Inside the build body, IZ is the build index and (IS IZ) is the
         -- let-bound idx; arg is weakened past both binders.
         EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $
           eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
               (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
               (zero t1)