diff options
Diffstat (limited to 'src/AST/UnMonoid.hs')
-rw-r--r-- | src/AST/UnMonoid.hs | 119 |
1 files changed, 15 insertions, 104 deletions
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index ec5e11e..ae9728a 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -2,7 +2,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeOperators #-} -module AST.UnMonoid where +module AST.UnMonoid (unMonoid, zero, plus) where import AST import CHAD.Types @@ -117,110 +117,21 @@ plusSparse t a b adder = (weakenExpr WSink a) onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) -onehot _ topprj arg = case topprj of - SAPHere -> arg +onehot typ topprj idx arg = case (typ, topprj) of + (_, SAPHere) -> arg - SAPFst prj -> _ + (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2)) + (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg)) -onehot t (SS dep) arg = case t of - STPair t1 t2 -> - case dep of - SZ -> EJust ext val - SS dep' -> - let STEither tidx1 tidx2 = typeOf idx - STEither tval1 tval2 = typeOf val - in EJust ext $ - ECase ext idx - (ECase ext (weakenExpr WSink val) - (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)) - (zero t2)) - (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair l/r")) - (ECase ext (weakenExpr WSink val) - (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair r/l") - (EPair ext (zero t1) - (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) + (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg)) + (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg)) - STEither t1 t2 -> - case dep of - SZ -> EJust ext val - SS dep' -> - let STEither tidx1 tidx2 = typeOf idx - STEither tval1 tval2 = typeOf val - in EJust ext $ - ECase ext idx - (ECase ext (weakenExpr WSink val) - (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))) - (EError ext (STEither (d2 t1) (d2 t2)) "onehot either l/r")) - (ECase ext (weakenExpr WSink val) - (EError ext (STEither (d2 t1) (d2 t2)) "onehot either r/l") - (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) + (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg) - STMaybe t1 -> EJust ext (onehot t1 dep idx val) - - STArr n t1 -> - ELet ext val $ - EBuild ext n (EFst ext (EVar ext (typeOf val) IZ)) - (onehotArrayElem t1 n (SS dep) - (EVar ext (tTup (sreplicate n tIx)) IZ) - (weakenExpr (WSink .> WSink) idx) - (ESnd ext (EVar ext (typeOf val) (IS IZ)))) - - STNil -> error "Cannot index into nil" - STScal{} -> error "Cannot index into scalar" - STAccum{} -> error "Accumulators not allowed in input program" - --- onehotArrayElem --- :: STy t -> SNat n -> SNat i --- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' --- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot --- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole --- -> Ex env (D2 t) --- onehotArrayElem t n dep eltidx idx val = --- ELet ext eltidx $ --- ELet ext (weakenExpr WSink idx) $ --- let (cond, elt) = onehotArrayElemRec t n dep --- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) --- (EVar ext (typeOf idx) IZ) --- (weakenExpr (WSink .> WSink) val) --- in eif cond elt (zero t) - --- -- AcIdx must be duplicable --- onehotArrayElemRec --- :: STy t -> SNat n -> SNat i --- -> [Ex env TIx] --- -> Ex env (AcIdx (TArr n (D2 t)) i) --- -> Ex env (AcValArr n (D2 t) i) --- -> (Ex env (TScal TBool), Ex env (D2 t)) --- onehotArrayElemRec _ n SZ eltidx _ val = --- (EConst ext STBool True --- ,EIdx ext val (reconstructFromOutsideIn n eltidx)) --- onehotArrayElemRec t SZ (SS dep) eltidx idx val = --- case eltidx of --- [] -> (EConst ext STBool True, onehot t dep idx val) --- _ -> error "onehotArrayElemRec: mismatched list length" --- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = --- case eltidx of --- i : eltidx' -> --- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val --- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) --- ,elt) --- [] -> error "onehotArrayElemRec: mismatched list length" - --- | Outermost index at the head. The input expression must be duplicable. -outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -outsideInIndex = \n idx -> go n idx [] - where - go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx] - go SZ _ acc = acc - go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc) - --- Takes a list with the outermost index at the head. Returns a tuple with the --- innermost index on the right. -reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) -reconstructFromOutsideIn = \n list -> go n (reverse list) - where - -- Takes list with the _innermost_ index at the head. - go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) - go SZ [] = ENil ext - go (SS n) (i:is) = EPair ext (go n is) i - go _ _ = error "reconstructFromOutsideIn: mismatched list length" + (STArr n t1, SAPArrIdx prj _) -> + let tidx = tTup (sreplicate n tIx) + in ELet ext idx $ + EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ + eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) + (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) + (zero t1) |