summaryrefslogtreecommitdiff
path: root/src/AST/UnMonoid.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/UnMonoid.hs')
-rw-r--r--src/AST/UnMonoid.hs119
1 files changed, 15 insertions, 104 deletions
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index ec5e11e..ae9728a 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -2,7 +2,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
-module AST.UnMonoid where
+module AST.UnMonoid (unMonoid, zero, plus) where
import AST
import CHAD.Types
@@ -117,110 +117,21 @@ plusSparse t a b adder =
(weakenExpr WSink a)
onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
-onehot _ topprj arg = case topprj of
- SAPHere -> arg
+onehot typ topprj idx arg = case (typ, topprj) of
+ (_, SAPHere) -> arg
- SAPFst prj -> _
+ (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2))
+ (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg))
-onehot t (SS dep) arg = case t of
- STPair t1 t2 ->
- case dep of
- SZ -> EJust ext val
- SS dep' ->
- let STEither tidx1 tidx2 = typeOf idx
- STEither tval1 tval2 = typeOf val
- in EJust ext $
- ECase ext idx
- (ECase ext (weakenExpr WSink val)
- (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))
- (zero t2))
- (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair l/r"))
- (ECase ext (weakenExpr WSink val)
- (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair r/l")
- (EPair ext (zero t1)
- (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
+ (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg))
+ (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg))
- STEither t1 t2 ->
- case dep of
- SZ -> EJust ext val
- SS dep' ->
- let STEither tidx1 tidx2 = typeOf idx
- STEither tval1 tval2 = typeOf val
- in EJust ext $
- ECase ext idx
- (ECase ext (weakenExpr WSink val)
- (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)))
- (EError ext (STEither (d2 t1) (d2 t2)) "onehot either l/r"))
- (ECase ext (weakenExpr WSink val)
- (EError ext (STEither (d2 t1) (d2 t2)) "onehot either r/l")
- (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
+ (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg)
- STMaybe t1 -> EJust ext (onehot t1 dep idx val)
-
- STArr n t1 ->
- ELet ext val $
- EBuild ext n (EFst ext (EVar ext (typeOf val) IZ))
- (onehotArrayElem t1 n (SS dep)
- (EVar ext (tTup (sreplicate n tIx)) IZ)
- (weakenExpr (WSink .> WSink) idx)
- (ESnd ext (EVar ext (typeOf val) (IS IZ))))
-
- STNil -> error "Cannot index into nil"
- STScal{} -> error "Cannot index into scalar"
- STAccum{} -> error "Accumulators not allowed in input program"
-
--- onehotArrayElem
--- :: STy t -> SNat n -> SNat i
--- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex'
--- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot
--- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole
--- -> Ex env (D2 t)
--- onehotArrayElem t n dep eltidx idx val =
--- ELet ext eltidx $
--- ELet ext (weakenExpr WSink idx) $
--- let (cond, elt) = onehotArrayElemRec t n dep
--- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ)))
--- (EVar ext (typeOf idx) IZ)
--- (weakenExpr (WSink .> WSink) val)
--- in eif cond elt (zero t)
-
--- -- AcIdx must be duplicable
--- onehotArrayElemRec
--- :: STy t -> SNat n -> SNat i
--- -> [Ex env TIx]
--- -> Ex env (AcIdx (TArr n (D2 t)) i)
--- -> Ex env (AcValArr n (D2 t) i)
--- -> (Ex env (TScal TBool), Ex env (D2 t))
--- onehotArrayElemRec _ n SZ eltidx _ val =
--- (EConst ext STBool True
--- ,EIdx ext val (reconstructFromOutsideIn n eltidx))
--- onehotArrayElemRec t SZ (SS dep) eltidx idx val =
--- case eltidx of
--- [] -> (EConst ext STBool True, onehot t dep idx val)
--- _ -> error "onehotArrayElemRec: mismatched list length"
--- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val =
--- case eltidx of
--- i : eltidx' ->
--- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val
--- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond)
--- ,elt)
--- [] -> error "onehotArrayElemRec: mismatched list length"
-
--- | Outermost index at the head. The input expression must be duplicable.
-outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx]
-outsideInIndex = \n idx -> go n idx []
- where
- go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx]
- go SZ _ acc = acc
- go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc)
-
--- Takes a list with the outermost index at the head. Returns a tuple with the
--- innermost index on the right.
-reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
-reconstructFromOutsideIn = \n list -> go n (reverse list)
- where
- -- Takes list with the _innermost_ index at the head.
- go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
- go SZ [] = ENil ext
- go (SS n) (i:is) = EPair ext (go n is) i
- go _ _ = error "reconstructFromOutsideIn: mismatched list length"
+ (STArr n t1, SAPArrIdx prj _) ->
+ let tidx = tTup (sreplicate n tIx)
+ in ELet ext idx $
+ EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $
+ eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
+ (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
+ (zero t1)