summaryrefslogtreecommitdiff
path: root/src/AST/UnMonoid.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/UnMonoid.hs')
-rw-r--r--src/AST/UnMonoid.hs118
1 files changed, 115 insertions, 3 deletions
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 1675dab..8da1e32 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t
unMonoid = \case
EZero t -> zero t
EPlus t a b -> plus t a b
- EOneHot t i a b -> _ t i a b
+ EOneHot t i a b -> onehot t i a b
EVar _ t i -> EVar ext t i
ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
@@ -51,7 +51,8 @@ zero STNil = ENil ext
zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
zero (STMaybe t) = ENothing ext (d2 t)
-zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t)
+zero (STArr SZ t) = EUnit ext (zero t)
+zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError (d2 t) "empty")
zero (STScal t) = case t of
STI32 -> ENil ext
STI64 -> ENil ext
@@ -85,7 +86,13 @@ plus (STMaybe t) a b =
plus (STArr n t) a b =
ELet ext a $
ELet ext (weakenExpr WSink b) $
- ECase
+ eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))
+ (EVar ext (STArr n (d2 t)) IZ)
+ (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ)))
+ (EVar ext (STArr n (d2 t)) (IS IZ))
+ (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ))
+ (EVar ext (STArr n (d2 t)) (IS IZ))
+ (EVar ext (STArr n (d2 t)) IZ)))
plus (STScal t) a b = case t of
STI32 -> ENil ext
STI64 -> ENil ext
@@ -108,3 +115,108 @@ plusSparse t a b adder =
(weakenExpr (WCopy (WCopy WSink)) adder)
(EVar ext (STMaybe t) (IS IZ))))
(weakenExpr WSink a)
+
+onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t)
+onehot _ SZ _ val = val
+onehot t (SS dep) idx val = case t of
+ STPair t1 t2 ->
+ case dep of
+ SZ -> EJust ext val
+ SS dep' ->
+ let STEither tidx1 tidx2 = typeOf idx
+ STEither tval1 tval2 = typeOf val
+ in EJust ext $
+ ECase ext idx
+ (ECase ext (weakenExpr WSink val)
+ (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))
+ (zero t2))
+ (EError (STPair (d2 t1) (d2 t2)) "onehot pair l/r"))
+ (ECase ext (weakenExpr WSink val)
+ (EError (STPair (d2 t1) (d2 t2)) "onehot pair r/l")
+ (EPair ext (zero t1)
+ (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
+
+ STEither t1 t2 ->
+ case dep of
+ SZ -> EJust ext val
+ SS dep' ->
+ let STEither tidx1 tidx2 = typeOf idx
+ STEither tval1 tval2 = typeOf val
+ in EJust ext $
+ ECase ext idx
+ (ECase ext (weakenExpr WSink val)
+ (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)))
+ (EError (STEither (d2 t1) (d2 t2)) "onehot either l/r"))
+ (ECase ext (weakenExpr WSink val)
+ (EError (STEither (d2 t1) (d2 t2)) "onehot either r/l")
+ (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
+
+ STMaybe t1 -> EJust ext (onehot t1 dep idx val)
+
+ STArr n t1 ->
+ ELet ext val $
+ EBuild ext n (EFst ext (EVar ext (typeOf val) IZ))
+ (onehotArrayElem t1 n (SS dep)
+ (EVar ext (tTup (sreplicate n tIx)) IZ)
+ (weakenExpr (WSink .> WSink) idx)
+ (ESnd ext (EVar ext (typeOf val) (IS IZ))))
+
+ STNil -> error "Cannot index into nil"
+ STScal{} -> error "Cannot index into scalar"
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+onehotArrayElem
+ :: STy t -> SNat n -> SNat i
+ -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex'
+ -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot
+ -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole
+ -> Ex env (D2 t)
+onehotArrayElem t n dep eltidx idx val =
+ ELet ext eltidx $
+ ELet ext (weakenExpr WSink idx) $
+ let (cond, elt) = onehotArrayElemRec t n dep
+ (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ)))
+ (EVar ext (typeOf idx) IZ)
+ (weakenExpr (WSink .> WSink) val)
+ in eif cond elt (zero t)
+
+-- AcIdx must be duplicable
+onehotArrayElemRec
+ :: STy t -> SNat n -> SNat i
+ -> [Ex env TIx]
+ -> Ex env (AcIdx (TArr n (D2 t)) i)
+ -> Ex env (AcValArr n (D2 t) i)
+ -> (Ex env (TScal TBool), Ex env (D2 t))
+onehotArrayElemRec _ n SZ eltidx _ val =
+ (EConst ext STBool True
+ ,EIdx ext val (reconstructFromOutsideIn n eltidx))
+onehotArrayElemRec t SZ (SS dep) eltidx idx val =
+ case eltidx of
+ [] -> (EConst ext STBool True, onehot t dep idx val)
+ _ -> error "onehotArrayElemRec: mismatched list length"
+onehotArrayElemRec t (SS n) (SS dep) eltidx idx val =
+ case eltidx of
+ i : eltidx' ->
+ let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val
+ in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond)
+ ,elt)
+ [] -> error "onehotArrayElemRec: mismatched list length"
+
+-- | Outermost index at the head. The input expression must be duplicable.
+outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx]
+outsideInIndex = \n idx -> go n idx []
+ where
+ go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx]
+ go SZ _ acc = acc
+ go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc)
+
+-- Takes a list with the outermost index at the head. Returns a tuple with the
+-- innermost index on the right.
+reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
+reconstructFromOutsideIn = \n list -> go n (reverse list)
+ where
+ -- Takes list with the _innermost_ index at the head.
+ go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
+ go SZ [] = ENil ext
+ go (SS n) (i:is) = EPair ext (go n is) i
+ go _ _ = error "reconstructFromOutsideIn: mismatched list length"