diff options
Diffstat (limited to 'src/AST/UnMonoid.hs')
-rw-r--r-- | src/AST/UnMonoid.hs | 88 |
1 files changed, 46 insertions, 42 deletions
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 4b6b523..ec5e11e 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t unMonoid = \case EZero _ t -> zero t EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) - EOneHot _ t i a b -> onehot t i (unMonoid a) (unMonoid b) + EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) EVar _ t i -> EVar ext t i ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) @@ -42,8 +42,8 @@ unMonoid = \case EShape _ e -> EShape ext (unMonoid e) EOp _ op e -> EOp ext op (unMonoid e) ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) - EWith _ a b -> EWith ext (unMonoid a) (unMonoid b) - EAccum _ n a b e -> EAccum ext n (unMonoid a) (unMonoid b) (unMonoid e) + EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) + EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) EError _ t s -> EError ext t s zero :: STy t -> Ex env (D2 t) @@ -116,9 +116,13 @@ plusSparse t a b adder = (EVar ext (STMaybe t) (IS IZ)))) (weakenExpr WSink a) -onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t) -onehot _ SZ _ val = val -onehot t (SS dep) idx val = case t of +onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) +onehot _ topprj arg = case topprj of + SAPHere -> arg + + SAPFst prj -> _ + +onehot t (SS dep) arg = case t of STPair t1 t2 -> case dep of SZ -> EJust ext val @@ -165,42 +169,42 @@ onehot t (SS dep) idx val = case t of STScal{} -> error "Cannot index into scalar" STAccum{} -> error "Accumulators not allowed in input program" -onehotArrayElem - :: STy t -> SNat n -> SNat i - -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' - -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot - -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole - -> Ex env (D2 t) -onehotArrayElem t n dep eltidx idx val = - ELet ext eltidx $ - ELet ext (weakenExpr WSink idx) $ - let (cond, elt) = onehotArrayElemRec t n dep - (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) - (EVar ext (typeOf idx) IZ) - (weakenExpr (WSink .> WSink) val) - in eif cond elt (zero t) - --- AcIdx must be duplicable -onehotArrayElemRec - :: STy t -> SNat n -> SNat i - -> [Ex env TIx] - -> Ex env (AcIdx (TArr n (D2 t)) i) - -> Ex env (AcValArr n (D2 t) i) - -> (Ex env (TScal TBool), Ex env (D2 t)) -onehotArrayElemRec _ n SZ eltidx _ val = - (EConst ext STBool True - ,EIdx ext val (reconstructFromOutsideIn n eltidx)) -onehotArrayElemRec t SZ (SS dep) eltidx idx val = - case eltidx of - [] -> (EConst ext STBool True, onehot t dep idx val) - _ -> error "onehotArrayElemRec: mismatched list length" -onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = - case eltidx of - i : eltidx' -> - let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val - in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) - ,elt) - [] -> error "onehotArrayElemRec: mismatched list length" +-- onehotArrayElem +-- :: STy t -> SNat n -> SNat i +-- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' +-- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot +-- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole +-- -> Ex env (D2 t) +-- onehotArrayElem t n dep eltidx idx val = +-- ELet ext eltidx $ +-- ELet ext (weakenExpr WSink idx) $ +-- let (cond, elt) = onehotArrayElemRec t n dep +-- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) +-- (EVar ext (typeOf idx) IZ) +-- (weakenExpr (WSink .> WSink) val) +-- in eif cond elt (zero t) + +-- -- AcIdx must be duplicable +-- onehotArrayElemRec +-- :: STy t -> SNat n -> SNat i +-- -> [Ex env TIx] +-- -> Ex env (AcIdx (TArr n (D2 t)) i) +-- -> Ex env (AcValArr n (D2 t) i) +-- -> (Ex env (TScal TBool), Ex env (D2 t)) +-- onehotArrayElemRec _ n SZ eltidx _ val = +-- (EConst ext STBool True +-- ,EIdx ext val (reconstructFromOutsideIn n eltidx)) +-- onehotArrayElemRec t SZ (SS dep) eltidx idx val = +-- case eltidx of +-- [] -> (EConst ext STBool True, onehot t dep idx val) +-- _ -> error "onehotArrayElemRec: mismatched list length" +-- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = +-- case eltidx of +-- i : eltidx' -> +-- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val +-- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) +-- ,elt) +-- [] -> error "onehotArrayElemRec: mismatched list length" -- | Outermost index at the head. The input expression must be duplicable. outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] |