diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/AST.hs | 27 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 118 |
2 files changed, 132 insertions, 13 deletions
@@ -396,17 +396,24 @@ emap f arr = (EVar ext (tTup (sreplicate n tIx)) IZ)) $ weakenExpr (WCopy (WSink .> WSink)) f -ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip a b = - let STArr n t1 = typeOf a - STArr _ t2 = typeOf b - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ +ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) +ezipWith f arr1 arr2 = + let STArr n t1 = typeOf arr1 + STArr _ t2 = typeOf arr2 + in ELet ext arr1 $ + ELet ext (weakenExpr WSink arr2) $ EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - EPair ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - (EIdx ext (EVar ext (STArr n t2) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) + ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ + weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f + +ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) +ezip arr1 arr2 = + let STArr _ t1 = typeOf arr1 + STArr _ t2 = typeOf arr2 + in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2 eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 1675dab..8da1e32 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t unMonoid = \case EZero t -> zero t EPlus t a b -> plus t a b - EOneHot t i a b -> _ t i a b + EOneHot t i a b -> onehot t i a b EVar _ t i -> EVar ext t i ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) @@ -51,7 +51,8 @@ zero STNil = ENil ext zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2)) zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2)) zero (STMaybe t) = ENothing ext (d2 t) -zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) +zero (STArr SZ t) = EUnit ext (zero t) +zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError (d2 t) "empty") zero (STScal t) = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -85,7 +86,13 @@ plus (STMaybe t) a b = plus (STArr n t) a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ - ECase + eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ)))) + (EVar ext (STArr n (d2 t)) IZ) + (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ))) + (EVar ext (STArr n (d2 t)) (IS IZ)) + (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)) + (EVar ext (STArr n (d2 t)) (IS IZ)) + (EVar ext (STArr n (d2 t)) IZ))) plus (STScal t) a b = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -108,3 +115,108 @@ plusSparse t a b adder = (weakenExpr (WCopy (WCopy WSink)) adder) (EVar ext (STMaybe t) (IS IZ)))) (weakenExpr WSink a) + +onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t) +onehot _ SZ _ val = val +onehot t (SS dep) idx val = case t of + STPair t1 t2 -> + case dep of + SZ -> EJust ext val + SS dep' -> + let STEither tidx1 tidx2 = typeOf idx + STEither tval1 tval2 = typeOf val + in EJust ext $ + ECase ext idx + (ECase ext (weakenExpr WSink val) + (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)) + (zero t2)) + (EError (STPair (d2 t1) (d2 t2)) "onehot pair l/r")) + (ECase ext (weakenExpr WSink val) + (EError (STPair (d2 t1) (d2 t2)) "onehot pair r/l") + (EPair ext (zero t1) + (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) + + STEither t1 t2 -> + case dep of + SZ -> EJust ext val + SS dep' -> + let STEither tidx1 tidx2 = typeOf idx + STEither tval1 tval2 = typeOf val + in EJust ext $ + ECase ext idx + (ECase ext (weakenExpr WSink val) + (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))) + (EError (STEither (d2 t1) (d2 t2)) "onehot either l/r")) + (ECase ext (weakenExpr WSink val) + (EError (STEither (d2 t1) (d2 t2)) "onehot either r/l") + (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) + + STMaybe t1 -> EJust ext (onehot t1 dep idx val) + + STArr n t1 -> + ELet ext val $ + EBuild ext n (EFst ext (EVar ext (typeOf val) IZ)) + (onehotArrayElem t1 n (SS dep) + (EVar ext (tTup (sreplicate n tIx)) IZ) + (weakenExpr (WSink .> WSink) idx) + (ESnd ext (EVar ext (typeOf val) (IS IZ)))) + + STNil -> error "Cannot index into nil" + STScal{} -> error "Cannot index into scalar" + STAccum{} -> error "Accumulators not allowed in input program" + +onehotArrayElem + :: STy t -> SNat n -> SNat i + -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' + -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot + -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole + -> Ex env (D2 t) +onehotArrayElem t n dep eltidx idx val = + ELet ext eltidx $ + ELet ext (weakenExpr WSink idx) $ + let (cond, elt) = onehotArrayElemRec t n dep + (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) + (EVar ext (typeOf idx) IZ) + (weakenExpr (WSink .> WSink) val) + in eif cond elt (zero t) + +-- AcIdx must be duplicable +onehotArrayElemRec + :: STy t -> SNat n -> SNat i + -> [Ex env TIx] + -> Ex env (AcIdx (TArr n (D2 t)) i) + -> Ex env (AcValArr n (D2 t) i) + -> (Ex env (TScal TBool), Ex env (D2 t)) +onehotArrayElemRec _ n SZ eltidx _ val = + (EConst ext STBool True + ,EIdx ext val (reconstructFromOutsideIn n eltidx)) +onehotArrayElemRec t SZ (SS dep) eltidx idx val = + case eltidx of + [] -> (EConst ext STBool True, onehot t dep idx val) + _ -> error "onehotArrayElemRec: mismatched list length" +onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = + case eltidx of + i : eltidx' -> + let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val + in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) + ,elt) + [] -> error "onehotArrayElemRec: mismatched list length" + +-- | Outermost index at the head. The input expression must be duplicable. +outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] +outsideInIndex = \n idx -> go n idx [] + where + go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx] + go SZ _ acc = acc + go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc) + +-- Takes a list with the outermost index at the head. Returns a tuple with the +-- innermost index on the right. +reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) +reconstructFromOutsideIn = \n list -> go n (reverse list) + where + -- Takes list with the _innermost_ index at the head. + go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) + go SZ [] = ENil ext + go (SS n) (i:is) = EPair ext (go n is) i + go _ _ = error "reconstructFromOutsideIn: mismatched list length" |