{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
-- | Expansion of the monoid primitives 'EZero', 'EPlus' and 'EOneHot' into
-- ordinary expressions built from the remaining constructors of 'Ex'.
module AST.UnMonoid where

import AST
import CHAD.Types
import Data


-- | Replace every 'EZero', 'EPlus' and 'EOneHot' node in an expression by
-- explicit code generated with 'zero', 'plus' and 'onehot' respectively. All
-- other constructors are rebuilt (with a fresh 'ext' annotation) after
-- recursing into their subexpressions.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
  EZero _ t -> zero t
  -- The operands of a monoid operation may themselves contain monoid
  -- operations (e.g. nested EPlus), so expand them before they are spliced
  -- into the generated code.
  EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
  EOneHot _ t i a b -> onehot t i (unMonoid a) (unMonoid b)
  EVar _ t i -> EVar ext t i
  ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
  EPair _ a b -> EPair ext (unMonoid a) (unMonoid b)
  EFst _ e -> EFst ext (unMonoid e)
  ESnd _ e -> ESnd ext (unMonoid e)
  ENil _ -> ENil ext
  EInl _ t e -> EInl ext t (unMonoid e)
  EInr _ t e -> EInr ext t (unMonoid e)
  ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b)
  ENothing _ t -> ENothing ext t
  EJust _ e -> EJust ext (unMonoid e)
  EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e)
  EConstArr _ n t x -> EConstArr ext n t x
  EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
  EFold1Inner _ a b c -> EFold1Inner ext (unMonoid a) (unMonoid b) (unMonoid c)
  ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
  EUnit _ e -> EUnit ext (unMonoid e)
  EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
  EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
  EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
  EConst _ t x -> EConst ext t x
  EIdx0 _ e -> EIdx0 ext (unMonoid e)
  EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
  EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b)
  EShape _ e -> EShape ext (unMonoid e)
  EOp _ op e -> EOp ext op (unMonoid e)
  ECustom _ t1 t2 t3 a b c e1 e2 ->
    ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
  EWith _ a b -> EWith ext (unMonoid a) (unMonoid b)
  EAccum _ n a b e -> EAccum ext n (unMonoid a) (unMonoid b) (unMonoid e)
  EError _ t s -> EError ext t s

-- | Expression for the zero value of @D2 t@.
--
-- Types whose @D2@ is a 'TMaybe' use 'ENothing'; integer and boolean scalars
-- use 'ENil'; float scalars use the literal 0. Arrays of positive rank use an
-- array whose shape components are all 0.
zero :: STy t -> Ex env (D2 t)
zero STNil = ENil ext
zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
zero (STMaybe t) = ENothing ext (d2 t)
-- Rank 0: wrap a zero element with EUnit.
zero (STArr SZ t) = EUnit ext (zero t)
-- All shape components are 0; the element expression is an EError, on the
-- assumption that EBuild never evaluates it for such a shape.
zero (STArr n t) =
  EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError ext (d2 t) "empty")
zero (STScal t) = case t of
  STI32 -> ENil ext
  STI64 -> ENil ext
  STF32 -> EConst ext STF32 0.0
  STF64 -> EConst ext STF64 0.0
  STBool -> ENil ext
zero STAccum{} = error "Accumulators not allowed in input program"

-- | Expression that adds two values of type @D2 t@.
--
-- For types whose @D2@ is a 'TMaybe', 'Nothing' is the identity (see
-- 'plusSparse'). Adding two sum values built with different constructors is a
-- runtime 'EError'. For arrays, if one operand's shape is empty (per
-- 'eshapeEmpty') the other operand is returned; otherwise the arrays are
-- added elementwise with 'ezipWith'.
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
plus STNil _ _ = ENil ext
plus (STPair t1 t2) a b =
  let t = STPair (d2 t1) (d2 t2)
  in plusSparse t a b $
       EPair ext
         (plus t1 (EFst ext (EVar ext t (IS IZ))) (EFst ext (EVar ext t IZ)))
         (plus t2 (ESnd ext (EVar ext t (IS IZ))) (ESnd ext (EVar ext t IZ)))
plus (STEither t1 t2) a b =
  let t = STEither (d2 t1) (d2 t2)
  in plusSparse t a b $
       -- outer case: first operand; inner case: second operand
       ECase ext (EVar ext t (IS IZ))
         (ECase ext (EVar ext t (IS IZ))
            (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
            (EError ext t "plus l+r"))
         (ECase ext (EVar ext t (IS IZ))
            (EError ext t "plus r+l")
            (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
plus (STMaybe t) a b =
  plusSparse (d2 t) a b $
    plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)
plus (STArr n t) a b =
  -- After these two lets: IS IZ is the first operand, IZ the second.
  ELet ext a $
  ELet ext (weakenExpr WSink b) $
    eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))
      (EVar ext (STArr n (d2 t)) IZ)
      (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ)))
         (EVar ext (STArr n (d2 t)) (IS IZ))
         (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ))
            (EVar ext (STArr n (d2 t)) (IS IZ))
            (EVar ext (STArr n (d2 t)) IZ)))
plus (STScal t) a b = case t of
  STI32 -> ENil ext
  STI64 -> ENil ext
  STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
  STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
  STBool -> ENil ext
plus STAccum{} _ _ = error "Accumulators not allowed in input program"

-- | Add two 'TMaybe' values where 'Nothing' is the identity.
--
-- If either operand is 'Nothing', the other operand is returned. If both are
-- 'Just', the contents are combined with @adder@. In @adder@'s environment
-- the first operand's content is at @IS IZ@ and the second's at @IZ@.
plusSparse :: STy a
           -> Ex env (TMaybe a) -> Ex env (TMaybe a)
           -> Ex (a : a : env) a
           -> Ex env (TMaybe a)
plusSparse t a b adder =
  ELet ext b $
    EMaybe ext
      -- first operand is Nothing: result is the second operand
      (EVar ext (STMaybe t) IZ)
      -- first operand is Just x: inspect the second operand
      (EJust ext
         (EMaybe ext
            -- second operand is Nothing: result is Just x
            (EVar ext t IZ)
            -- second operand is Just y: combine x (IS IZ) and y (IZ)
            (weakenExpr (WCopy (WCopy WSink)) adder)
            (EVar ext (STMaybe t) (IS IZ))))
      (weakenExpr WSink a)

-- | Expression for a value of type @D2 t@ that is zero except at the position
-- selected by @idx@, where it holds @val@. At depth 0 ('SZ') @val@ itself is
-- returned.
onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t)
onehot _ SZ _ val = val
onehot t (SS dep) idx val = case t of
  STPair t1 t2 -> case dep of
    SZ -> EJust ext val
    SS dep' ->
      -- idx and val are both sums selecting the same component; mismatched
      -- sides are a runtime error.
      let STEither tidx1 tidx2 = typeOf idx
          STEither tval1 tval2 = typeOf val
      in EJust ext $
           ECase ext idx
             (ECase ext (weakenExpr WSink val)
                (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)) (zero t2))
                (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair l/r"))
             (ECase ext (weakenExpr WSink val)
                (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair r/l")
                (EPair ext (zero t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
  STEither t1 t2 -> case dep of
    SZ -> EJust ext val
    SS dep' ->
      let STEither tidx1 tidx2 = typeOf idx
          STEither tval1 tval2 = typeOf val
      in EJust ext $
           ECase ext idx
             (ECase ext (weakenExpr WSink val)
                (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)))
                (EError ext (STEither (d2 t1) (d2 t2)) "onehot either l/r"))
             (ECase ext (weakenExpr WSink val)
                (EError ext (STEither (d2 t1) (d2 t2)) "onehot either r/l")
                (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
  STMaybe t1 -> EJust ext (onehot t1 dep idx val)
  STArr n t1 ->
    -- The first component of val is passed to EBuild as the array shape; the
    -- second is the value handed to onehotArrayElem for each element.
    ELet ext val $
      EBuild ext n (EFst ext (EVar ext (typeOf val) IZ)) $
        onehotArrayElem t1 n (SS dep)
          (EVar ext (tTup (sreplicate n tIx)) IZ)
          (weakenExpr (WSink .> WSink) idx)
          (ESnd ext (EVar ext (typeOf val) (IS IZ)))
  STNil -> error "Cannot index into nil"
  STScal{} -> error "Cannot index into scalar"
  STAccum{} -> error "Accumulators not allowed in input program"

-- | Compute one element of a one-hot array: the element selected by @idx@
-- gets its contribution from @val@, every other element is @'zero' t@.
onehotArrayElem
  :: STy t
  -> SNat n
  -> SNat i
  -> Ex env (Tup (Replicate n TIx))
     -- ^ index of the current element as a tuple (innermost index on the
     -- right); converted to outside-in order here with 'outsideInIndex'
  -> Ex env (AcIdx (TArr n (D2 t)) i)  -- ^ where to put the one-hot
  -> Ex env (AcValArr n (D2 t) i)  -- ^ value to put in the hole
  -> Ex env (D2 t)
onehotArrayElem t n dep eltidx idx val =
  -- Bind both indices so they can be duplicated freely below.
  ELet ext eltidx $
  ELet ext (weakenExpr WSink idx) $
    let (cond, elt) =
          onehotArrayElemRec t n dep
            (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ)))
            (EVar ext (typeOf idx) IZ)
            (weakenExpr (WSink .> WSink) val)
    in eif cond elt (zero t)

-- | For each array dimension that @idx@ descends into (outermost first),
-- compare the element's index component with @idx@'s component. Once the
-- depth reaches 0, index @val@ with the remaining element index components;
-- if all dimensions are consumed first, descend into the element with
-- 'onehot'. Returns the conjunction of the comparisons and the element value.
--
-- AcIdx must be duplicable: @idx@ is used more than once.
onehotArrayElemRec :: STy t -> SNat n -> SNat i
                   -> [Ex env TIx]
                   -> Ex env (AcIdx (TArr n (D2 t)) i)
                   -> Ex env (AcValArr n (D2 t) i)
                   -> (Ex env (TScal TBool), Ex env (D2 t))
onehotArrayElemRec _ n SZ eltidx _ val =
  (EConst ext STBool True
  ,EIdx ext val (reconstructFromOutsideIn n eltidx))
onehotArrayElemRec t SZ (SS dep) eltidx idx val = case eltidx of
  [] -> (EConst ext STBool True, onehot t dep idx val)
  _ -> error "onehotArrayElemRec: mismatched list length"
onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = case eltidx of
  i : eltidx' ->
    let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val
    in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond)
       ,elt)
  [] -> error "onehotArrayElemRec: mismatched list length"

-- | Split an index tuple into its components, outermost index at the head.
-- The input expression is duplicated (once per component), so it must be
-- duplicable.
outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx]
outsideInIndex = \n idx -> go n idx []
  where
    go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx]
    go SZ _ acc = acc
    go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc)

-- | Takes a list with the outermost index at the head (the order produced by
-- 'outsideInIndex'). Returns a tuple with the innermost index on the right.
reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
reconstructFromOutsideIn = \n list -> go n (reverse list)
  where
    -- Takes a list with the _innermost_ index at the head.
    go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
    go SZ [] = ENil ext
    go (SS n) (i:is) = EPair ext (go n is) i
    go _ _ = error "reconstructFromOutsideIn: mismatched list length"