diff options
Diffstat (limited to 'src/AST')
-rw-r--r-- | src/AST/UnMonoid.hs | 118 |
1 files changed, 115 insertions, 3 deletions
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 1675dab..8da1e32 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t unMonoid = \case EZero t -> zero t EPlus t a b -> plus t a b - EOneHot t i a b -> _ t i a b + EOneHot t i a b -> onehot t i a b EVar _ t i -> EVar ext t i ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) @@ -51,7 +51,8 @@ zero STNil = ENil ext zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2)) zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2)) zero (STMaybe t) = ENothing ext (d2 t) -zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) +zero (STArr SZ t) = EUnit ext (zero t) +zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError (d2 t) "empty") zero (STScal t) = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -85,7 +86,13 @@ plus (STMaybe t) a b = plus (STArr n t) a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ - ECase + eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ)))) + (EVar ext (STArr n (d2 t)) IZ) + (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ))) + (EVar ext (STArr n (d2 t)) (IS IZ)) + (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)) + (EVar ext (STArr n (d2 t)) (IS IZ)) + (EVar ext (STArr n (d2 t)) IZ))) plus (STScal t) a b = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -108,3 +115,108 @@ plusSparse t a b adder = (weakenExpr (WCopy (WCopy WSink)) adder) (EVar ext (STMaybe t) (IS IZ)))) (weakenExpr WSink a) + +onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t) +onehot _ SZ _ val = val +onehot t (SS dep) idx val = case t of + STPair t1 t2 -> + case dep of + SZ -> EJust ext val + SS dep' -> + let STEither tidx1 tidx2 = typeOf idx + STEither tval1 tval2 = typeOf val + in EJust ext $ + ECase ext idx + (ECase ext (weakenExpr WSink val) + (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)) + (zero t2)) + (EError (STPair (d2 t1) (d2 t2)) "onehot pair l/r")) + (ECase ext (weakenExpr WSink val) + (EError (STPair (d2 t1) (d2 t2)) "onehot pair r/l") + (EPair ext (zero t1) + (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) + + STEither t1 t2 -> + case dep of + SZ -> EJust ext val + SS dep' -> + let STEither tidx1 tidx2 = typeOf idx + STEither tval1 tval2 = typeOf val + in EJust ext $ + ECase ext idx + (ECase ext (weakenExpr WSink val) + (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))) + (EError (STEither (d2 t1) (d2 t2)) "onehot either l/r")) + (ECase ext (weakenExpr WSink val) + (EError (STEither (d2 t1) (d2 t2)) "onehot either r/l") + (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) + + STMaybe t1 -> EJust ext (onehot t1 dep idx val) + + STArr n t1 -> + ELet ext val $ + EBuild ext n (EFst ext (EVar ext (typeOf val) IZ)) + (onehotArrayElem t1 n (SS dep) + (EVar ext (tTup (sreplicate n tIx)) IZ) + (weakenExpr (WSink .> WSink) idx) + (ESnd ext (EVar ext (typeOf val) (IS IZ)))) + + STNil -> error "Cannot index into nil" + STScal{} -> error "Cannot index into scalar" + STAccum{} -> error "Accumulators not allowed in input program" + +onehotArrayElem + :: STy t -> SNat n -> SNat i + -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' + -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot + -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole + -> Ex env (D2 t) +onehotArrayElem t n dep eltidx idx val = + ELet ext eltidx $ + ELet ext (weakenExpr WSink idx) $ + let (cond, elt) = onehotArrayElemRec t n dep + (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) + (EVar ext (typeOf idx) IZ) + (weakenExpr (WSink .> WSink) val) + in eif cond elt (zero t) + +-- AcIdx must be duplicable +onehotArrayElemRec + :: STy t -> SNat n -> SNat i + -> [Ex env TIx] + -> Ex env (AcIdx (TArr n (D2 t)) i) + -> Ex env (AcValArr n (D2 t) i) + -> (Ex env (TScal TBool), Ex env (D2 t)) +onehotArrayElemRec _ n SZ eltidx _ val = + (EConst ext STBool True + ,EIdx ext val (reconstructFromOutsideIn n eltidx)) +onehotArrayElemRec t SZ (SS dep) eltidx idx val = + case eltidx of + [] -> (EConst ext STBool True, onehot t dep idx val) + _ -> error "onehotArrayElemRec: mismatched list length" +onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = + case eltidx of + i : eltidx' -> + let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val + in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) + ,elt) + [] -> error "onehotArrayElemRec: mismatched list length" + +-- | Outermost index at the head. The input expression must be duplicable. +outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] +outsideInIndex = \n idx -> go n idx [] + where + go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx] + go SZ _ acc = acc + go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc) + +-- Takes a list with the outermost index at the head. Returns a tuple with the +-- innermost index on the right. +reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) +reconstructFromOutsideIn = \n list -> go n (reverse list) + where + -- Takes list with the _innermost_ index at the head. + go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) + go SZ [] = ENil ext + go (SS n) (i:is) = EPair ext (go n is) i + go _ _ = error "reconstructFromOutsideIn: mismatched list length" |