summaryrefslogtreecommitdiff
path: root/src/AST/UnMonoid.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/UnMonoid.hs')
-rw-r--r--src/AST/UnMonoid.hs88
1 files changed, 46 insertions, 42 deletions
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 4b6b523..ec5e11e 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t
unMonoid = \case
EZero _ t -> zero t
EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
- EOneHot _ t i a b -> onehot t i (unMonoid a) (unMonoid b)
+ EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
EVar _ t i -> EVar ext t i
ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
@@ -42,8 +42,8 @@ unMonoid = \case
EShape _ e -> EShape ext (unMonoid e)
EOp _ op e -> EOp ext op (unMonoid e)
ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
- EWith _ a b -> EWith ext (unMonoid a) (unMonoid b)
- EAccum _ n a b e -> EAccum ext n (unMonoid a) (unMonoid b) (unMonoid e)
+ EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
+ EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
EError _ t s -> EError ext t s
zero :: STy t -> Ex env (D2 t)
@@ -116,9 +116,13 @@ plusSparse t a b adder =
(EVar ext (STMaybe t) (IS IZ))))
(weakenExpr WSink a)
-onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t)
-onehot _ SZ _ val = val
-onehot t (SS dep) idx val = case t of
+onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
+onehot _ topprj arg = case topprj of
+ SAPHere -> arg
+
+ SAPFst prj -> _
+
+onehot t (SS dep) arg = case t of
STPair t1 t2 ->
case dep of
SZ -> EJust ext val
@@ -165,42 +169,42 @@ onehot t (SS dep) idx val = case t of
STScal{} -> error "Cannot index into scalar"
STAccum{} -> error "Accumulators not allowed in input program"
-onehotArrayElem
- :: STy t -> SNat n -> SNat i
- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex'
- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot
- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole
- -> Ex env (D2 t)
-onehotArrayElem t n dep eltidx idx val =
- ELet ext eltidx $
- ELet ext (weakenExpr WSink idx) $
- let (cond, elt) = onehotArrayElemRec t n dep
- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ)))
- (EVar ext (typeOf idx) IZ)
- (weakenExpr (WSink .> WSink) val)
- in eif cond elt (zero t)
-
--- AcIdx must be duplicable
-onehotArrayElemRec
- :: STy t -> SNat n -> SNat i
- -> [Ex env TIx]
- -> Ex env (AcIdx (TArr n (D2 t)) i)
- -> Ex env (AcValArr n (D2 t) i)
- -> (Ex env (TScal TBool), Ex env (D2 t))
-onehotArrayElemRec _ n SZ eltidx _ val =
- (EConst ext STBool True
- ,EIdx ext val (reconstructFromOutsideIn n eltidx))
-onehotArrayElemRec t SZ (SS dep) eltidx idx val =
- case eltidx of
- [] -> (EConst ext STBool True, onehot t dep idx val)
- _ -> error "onehotArrayElemRec: mismatched list length"
-onehotArrayElemRec t (SS n) (SS dep) eltidx idx val =
- case eltidx of
- i : eltidx' ->
- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val
- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond)
- ,elt)
- [] -> error "onehotArrayElemRec: mismatched list length"
+-- onehotArrayElem
+-- :: STy t -> SNat n -> SNat i
+-- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex'
+-- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot
+-- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole
+-- -> Ex env (D2 t)
+-- onehotArrayElem t n dep eltidx idx val =
+-- ELet ext eltidx $
+-- ELet ext (weakenExpr WSink idx) $
+-- let (cond, elt) = onehotArrayElemRec t n dep
+-- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ)))
+-- (EVar ext (typeOf idx) IZ)
+-- (weakenExpr (WSink .> WSink) val)
+-- in eif cond elt (zero t)
+
+-- -- AcIdx must be duplicable
+-- onehotArrayElemRec
+-- :: STy t -> SNat n -> SNat i
+-- -> [Ex env TIx]
+-- -> Ex env (AcIdx (TArr n (D2 t)) i)
+-- -> Ex env (AcValArr n (D2 t) i)
+-- -> (Ex env (TScal TBool), Ex env (D2 t))
+-- onehotArrayElemRec _ n SZ eltidx _ val =
+-- (EConst ext STBool True
+-- ,EIdx ext val (reconstructFromOutsideIn n eltidx))
+-- onehotArrayElemRec t SZ (SS dep) eltidx idx val =
+-- case eltidx of
+-- [] -> (EConst ext STBool True, onehot t dep idx val)
+-- _ -> error "onehotArrayElemRec: mismatched list length"
+-- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val =
+-- case eltidx of
+-- i : eltidx' ->
+-- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val
+-- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond)
+-- ,elt)
+-- [] -> error "onehotArrayElemRec: mismatched list length"
-- | Outermost index at the head. The input expression must be duplicable.
outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx]