summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-14 23:29:51 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-14 23:29:51 +0100
commitfff6beda3523abce3d27037ea2fb020fce31f502 (patch)
treefbcfb97a8eb2366ec46e0612b4b308741a8b601e
parent137eaa13144c2599ac29da9ebd3af24ac1ce8968 (diff)
Much process with accumulator revamp
-rw-r--r--src/AST.hs6
-rw-r--r--src/AST/Accum.hs33
-rw-r--r--src/AST/Pretty.hs29
-rw-r--r--src/AST/Types.hs2
-rw-r--r--src/AST/UnMonoid.hs119
-rw-r--r--src/Analysis/Identity.hs8
-rw-r--r--src/CHAD/Accum.hs2
-rw-r--r--src/CHAD/Types.hs4
-rw-r--r--src/Interpreter.hs267
-rw-r--r--src/Interpreter/Rep.hs35
-rw-r--r--src/Simplify.hs59
11 files changed, 290 insertions, 274 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 3fb8822..1cdd710 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -88,8 +88,8 @@ data Expr x env t where
-> Expr x env t
-- accumulation effect on monoids
- EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum (D2 t) : env) a -> Expr x env (TPair a (D2 t))
- EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum (D2 a)) -> Expr x env TNil
+ EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t))
+ EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
EZero :: x (D2 t) -> STy t -> Expr x env (D2 t)
@@ -381,6 +381,8 @@ ebuildUp1 n sh size f =
eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
eidxEq SZ _ _ = EConst ext STBool True
+eidxEq (SS SZ) a b =
+ EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b))
eidxEq (SS n) a b
| let ty = tTup (sreplicate (SS n) tIx)
= ELet ext a $
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 163f1c3..6c46ad5 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -21,16 +21,17 @@ data AcPrj
| APArrIdx AcPrj
| APArrSlice Nat
--- | @b@ is a small part of @a@, indicated by the projection.
+-- | @b@ is a small part of @a@, indicated by the projection @p@.
data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
SAPHere :: SAcPrj APHere a a
- SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair t a) b
- SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair a t) b
- SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither t a) b
+ SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b
+ SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b
+ SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b
SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b
- SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe t) b
- SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n t) b
- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
+ SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b
+ SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b
+ -- TODO:
+ -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
type family AcIdx p t where
@@ -40,5 +41,19 @@ type family AcIdx p t where
AcIdx (APLeft p) (TEither a b) = AcIdx p a
AcIdx (APRight p) (TEither a b) = AcIdx p b
AcIdx (APJust p) (TMaybe a) = AcIdx p a
- AcIdx (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx p a)
- AcIdx (APArrSlice m) (TArr n a) = Tup (Replicate m TIx)
+ AcIdx (APArrIdx p) (TArr n a) =
+ -- ((index, array shape), recursive info)
+ TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx)))
+ (AcIdx p a)
+ -- AcIdx (APArrSlice m) (TArr n a) =
+ -- -- (index, array shape)
+ -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
+
+acPrjTy :: SAcPrj p a b -> STy a -> STy b
+acPrjTy SAPHere t = t
+acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t
+acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t
+acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t
+acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t
+acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t
+acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index f91aff2..b9406d7 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -235,9 +235,9 @@ ppExpr' d val expr = case expr of
,e1'
,e2']
- EWith _ e1 e2 -> do
+ EWith _ t e1 e2 -> do
e1' <- ppExpr' 11 val e1
- name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
+ name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2
e2' <- ppExpr' 0 (Const name `SCons` val) e2
return $ ppParen (d > 0) $
group $ flatAlt
@@ -247,12 +247,12 @@ ppExpr' d val expr = case expr of
<> hardline <> e2')
(ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
- EAccum _ i e1 e2 e3 -> do
+ EAccum _ _ prj e1 e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (show (fromSNat i)), e1', e2', e3']
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3']
EZero _ t -> return $ ppParen (d > 0) $
annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t
@@ -263,11 +263,11 @@ ppExpr' d val expr = case expr of
return $ ppParen (d > 10) $
ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b']
- EOneHot _ _ i a b -> do
+ EOneHot _ _ prj a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (show (fromSNat i)), a', b']
+ ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b']
EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s)
@@ -300,6 +300,15 @@ ppLam :: [ADoc] -> ADoc -> ADoc
ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"])
<> softline <> body <> ppString ")")
+ppAcPrj :: SAcPrj p a b -> String
+ppAcPrj SAPHere = "@"
+ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)"
+ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")"
+ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)"
+ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")"
+ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj
+ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n)
+
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
@@ -355,6 +364,14 @@ ppParen :: Bool -> Doc x -> Doc x
ppParen True = parens
ppParen False = id
+intSubscript :: Int -> String
+intSubscript = \case 0 -> "₀"
+ n | n < 0 -> '₋' : go (-n) ""
+ | otherwise -> go n ""
+ where go 0 suff = suff
+ go n suff = let (q, r) = n `quotRem` 10
+ in go q ("₀₁₂₃₄₅₆₇₈₉" !! r : suff)
+
data Annot = AKey | AWith | AHighlight | AMonoid | AExt
deriving (Show)
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index be7cffe..0b41671 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -21,7 +21,7 @@ data Ty
| TMaybe Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
- | TAccum Ty
+ | TAccum Ty -- ^ the accumulator contains D2 of this type
deriving (Show, Eq, Ord)
data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index ec5e11e..ae9728a 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -2,7 +2,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
-module AST.UnMonoid where
+module AST.UnMonoid (unMonoid, zero, plus) where
import AST
import CHAD.Types
@@ -117,110 +117,21 @@ plusSparse t a b adder =
(weakenExpr WSink a)
onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t)
-onehot _ topprj arg = case topprj of
- SAPHere -> arg
+onehot typ topprj idx arg = case (typ, topprj) of
+ (_, SAPHere) -> arg
- SAPFst prj -> _
+ (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2))
+ (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg))
-onehot t (SS dep) arg = case t of
- STPair t1 t2 ->
- case dep of
- SZ -> EJust ext val
- SS dep' ->
- let STEither tidx1 tidx2 = typeOf idx
- STEither tval1 tval2 = typeOf val
- in EJust ext $
- ECase ext idx
- (ECase ext (weakenExpr WSink val)
- (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))
- (zero t2))
- (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair l/r"))
- (ECase ext (weakenExpr WSink val)
- (EError ext (STPair (d2 t1) (d2 t2)) "onehot pair r/l")
- (EPair ext (zero t1)
- (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
+ (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg))
+ (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg))
- STEither t1 t2 ->
- case dep of
- SZ -> EJust ext val
- SS dep' ->
- let STEither tidx1 tidx2 = typeOf idx
- STEither tval1 tval2 = typeOf val
- in EJust ext $
- ECase ext idx
- (ECase ext (weakenExpr WSink val)
- (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)))
- (EError ext (STEither (d2 t1) (d2 t2)) "onehot either l/r"))
- (ECase ext (weakenExpr WSink val)
- (EError ext (STEither (d2 t1) (d2 t2)) "onehot either r/l")
- (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
+ (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg)
- STMaybe t1 -> EJust ext (onehot t1 dep idx val)
-
- STArr n t1 ->
- ELet ext val $
- EBuild ext n (EFst ext (EVar ext (typeOf val) IZ))
- (onehotArrayElem t1 n (SS dep)
- (EVar ext (tTup (sreplicate n tIx)) IZ)
- (weakenExpr (WSink .> WSink) idx)
- (ESnd ext (EVar ext (typeOf val) (IS IZ))))
-
- STNil -> error "Cannot index into nil"
- STScal{} -> error "Cannot index into scalar"
- STAccum{} -> error "Accumulators not allowed in input program"
-
--- onehotArrayElem
--- :: STy t -> SNat n -> SNat i
--- -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex'
--- -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot
--- -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole
--- -> Ex env (D2 t)
--- onehotArrayElem t n dep eltidx idx val =
--- ELet ext eltidx $
--- ELet ext (weakenExpr WSink idx) $
--- let (cond, elt) = onehotArrayElemRec t n dep
--- (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ)))
--- (EVar ext (typeOf idx) IZ)
--- (weakenExpr (WSink .> WSink) val)
--- in eif cond elt (zero t)
-
--- -- AcIdx must be duplicable
--- onehotArrayElemRec
--- :: STy t -> SNat n -> SNat i
--- -> [Ex env TIx]
--- -> Ex env (AcIdx (TArr n (D2 t)) i)
--- -> Ex env (AcValArr n (D2 t) i)
--- -> (Ex env (TScal TBool), Ex env (D2 t))
--- onehotArrayElemRec _ n SZ eltidx _ val =
--- (EConst ext STBool True
--- ,EIdx ext val (reconstructFromOutsideIn n eltidx))
--- onehotArrayElemRec t SZ (SS dep) eltidx idx val =
--- case eltidx of
--- [] -> (EConst ext STBool True, onehot t dep idx val)
--- _ -> error "onehotArrayElemRec: mismatched list length"
--- onehotArrayElemRec t (SS n) (SS dep) eltidx idx val =
--- case eltidx of
--- i : eltidx' ->
--- let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val
--- in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond)
--- ,elt)
--- [] -> error "onehotArrayElemRec: mismatched list length"
-
--- | Outermost index at the head. The input expression must be duplicable.
-outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx]
-outsideInIndex = \n idx -> go n idx []
- where
- go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx]
- go SZ _ acc = acc
- go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc)
-
--- Takes a list with the outermost index at the head. Returns a tuple with the
--- innermost index on the right.
-reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
-reconstructFromOutsideIn = \n list -> go n (reverse list)
- where
- -- Takes list with the _innermost_ index at the head.
- go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
- go SZ [] = ENil ext
- go (SS n) (i:is) = EPair ext (go n is) i
- go _ _ = error "reconstructFromOutsideIn: mismatched list length"
+ (STArr n t1, SAPArrIdx prj _) ->
+ let tidx = tTup (sreplicate n tIx)
+ in ELet ext idx $
+ EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $
+ eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))))
+ (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
+ (zero t1)
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 020ca34..5e36dde 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -237,20 +237,20 @@ idana env expr = case expr of
res <- genIds t4
pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5')
- EWith _ e1 e2 -> do
+ EWith _ t e1 e2 -> do
let t1 = typeOf e1
(_, e1') <- idana env e1
x1 <- VIAccum <$> genId
(v2, e2') <- idana (x1 `SCons` env) e2
x2 <- genIds t1
let res = VIPair v2 x2
- pure (res, EWith res e1' e2')
+ pure (res, EWith res t e1' e2')
- EAccum _ i e1 e2 e3 -> do
+ EAccum _ t prj e1 e2 e3 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
(_, e3') <- idana env e3
- pure (VINil, EAccum VINil i e1' e2' e3')
+ pure (VINil, EAccum VINil t prj e1' e2' e3')
EZero _ t -> do
res <- genIds (d2 t)
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
index 14a1d3b..b61b5ff 100644
--- a/src/CHAD/Accum.hs
+++ b/src/CHAD/Accum.hs
@@ -12,7 +12,7 @@ makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex e
makeAccumulators SNil e = e
makeAccumulators (t `SCons` envpro) e =
makeAccumulators envpro $
- EWith ext (EZero ext t) e
+ EWith ext t (EZero ext t) e
uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
uninvertTup SNil _ e = EPair ext e (ENil ext)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index a8614cf..e8ec0c9 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -40,7 +40,7 @@ type family D2E env where
type family D2AcE env where
D2AcE '[] = '[]
- D2AcE (t : env) = TAccum (D2 t) : D2AcE env
+ D2AcE (t : env) = TAccum t : D2AcE env
d1 :: STy t -> STy (D1 t)
d1 STNil = STNil
@@ -75,7 +75,7 @@ d2e (t `SCons` ts) = d2 t `SCons` d2e ts
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace SNil = SNil
-d2ace (t `SCons` ts) = STAccum (d2 t) `SCons` d2ace ts
+d2ace (t `SCons` ts) = STAccum t `SCons` d2ace ts
data CHADConfig = CHADConfig
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index d80a76e..11caac0 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -21,6 +21,7 @@ module Interpreter (
import Control.Monad (foldM, join, when)
import Data.Bifunctor (bimap)
+import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
import Data.Functor.Identity
import Data.Kind (Type)
@@ -134,26 +135,25 @@ interpret'Rec env = \case
e1' <- interpret' env e1
e2' <- interpret' env e2
interpret' (Value e2' `SCons` Value e1' `SCons` SNil) pr
- EWith _ e1 e2 -> do
+ EWith _ t e1 e2 -> do
initval <- interpret' env e1
- withAccum (typeOf e1) (typeOf e2) initval $ \accum ->
+ withAccum t (typeOf e2) initval $ \accum ->
interpret' (Value accum `SCons` env) e2
- EAccum _ i e1 e2 e3 -> do
- let STAccum t = typeOf e3
+ EAccum _ t p e1 e2 e3 -> do
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
- accumAddSparse t i accum idx val
+ accumAddSparse t p accum idx val
EZero _ t -> do
return $ zeroD2 t
EPlus _ t a b -> do
a' <- interpret' env a
b' <- interpret' env b
return $ addD2s t a' b'
- EOneHot _ t i a b -> do
+ EOneHot _ t p a b -> do
a' <- interpret' env a
b' <- interpret' env b
- return $ onehotD2 i t a' b'
+ return $ onehotD2 p t a' b'
EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s
interpretOp :: SOp a t -> Rep a -> Rep t
@@ -230,44 +230,37 @@ addD2s typ a b = case typ of
STBool -> ()
STAccum{} -> error "Plus of Accum"
-onehotD2 :: SNat i -> STy t -> Rep (AcIdx (D2 t) i) -> Rep (AcVal (D2 t) i) -> Rep (D2 t)
-onehotD2 SZ _ () v = v
-onehotD2 _ STNil _ _ = ()
-onehotD2 (SS SZ ) (STPair _ _ ) () val = Just val
-onehotD2 (SS (SS i)) (STPair t1 t2) (Left idx) (Left val) = Just (onehotD2 i t1 idx val, zeroD2 t2)
-onehotD2 (SS (SS i)) (STPair t1 t2) (Right idx) (Right val) = Just (zeroD2 t1, onehotD2 i t2 idx val)
-onehotD2 (SS _ ) (STPair _ _ ) _ _ = error "onehotD2: pair: mismatched index and value"
-onehotD2 (SS SZ ) (STEither _ _ ) () val = Just val
-onehotD2 (SS (SS i)) (STEither t1 _ ) (Left idx) (Left val) = Just (Left (onehotD2 i t1 idx val))
-onehotD2 (SS (SS i)) (STEither _ t2) (Right idx) (Right val) = Just (Right (onehotD2 i t2 idx val))
-onehotD2 (SS _ ) (STEither _ _ ) _ _ = error "onehotD2: either: mismatched index and value"
-onehotD2 (SS i ) (STMaybe t) idx val = Just (onehotD2 i t idx val)
-onehotD2 (SS i ) (STArr n t) idx val = runIdentity $
- onehotArray (d2 t) (\i' idx' v' -> Identity (onehotD2 i' t idx' v')) (Identity (zeroD2 t)) n (SS i) idx val
-onehotD2 SS{} STScal{} _ _ = error "onehotD2: cannot index into scalar"
-onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator"
-
-withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t)
+onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a)
+onehotD2 SAPHere _ _ val = val
+onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b)
+onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val)
+onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val))
+onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val))
+onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val)
+onehotD2 (SAPArrIdx prj _) (STArr n a) idx val =
+ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx
+
+withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t))
withAccum t _ initval f = AcM $ do
- accum <- newAcSparse t SZ () initval
+ accum <- newAcSparse t SAPHere () initval
out <- case f accum of AcM m -> m
val <- readAcSparse t accum
return (out, val)
-newAcZero :: STy t -> IO (RepAcSparse t)
+newAcZero :: STy t -> IO (RepAc t)
newAcZero = \case
STNil -> return ()
- STPair t1 t2 -> newIORef =<< (,) <$> newAcZero t1 <*> newAcZero t2
+ STPair{} -> newIORef Nothing
+ STEither{} -> newIORef Nothing
STMaybe _ -> newIORef Nothing
STArr n _ -> newIORef (emptyArray n)
STScal sty -> case sty of
- STI32 -> newIORef 0
- STI64 -> newIORef 0
+ STI32 -> return ()
+ STI64 -> return ()
STF32 -> newIORef 0.0
STF64 -> newIORef 0.0
- STBool -> error "Accumulator of Bool"
+ STBool -> return ()
STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
-- | Inverted index: the outermost index is at the /outside/ of this list.
data PartialInvIndex n m where
@@ -322,95 +315,144 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n
piindexConcat PIIxEnd ix = ix
piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix)
-newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t)
-newAcSparse typ SZ () val = case typ of
- STNil -> return ()
- STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
- STMaybe t -> newIORef =<< traverse (newAcDense t SZ ()) val
- STArr _ t -> newIORef =<< traverse (newAcSparse t SZ ()) val
- STScal{} -> newIORef val
- STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
-newAcSparse typ (SS dep) idx val = case typ of
- STNil -> return ()
- STPair t1 t2 -> newIORef =<< case (idx, val) of
- (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
- (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
- _ -> error "Index/value mismatch in newAc pair"
- STMaybe t -> newIORef =<< Just <$> newAcDense t dep idx val
- STArr dim (t :: STy t) -> newIORef =<< newAcArray dim t (SS dep) idx val
- STScal{} -> error "Cannot index into scalar"
- STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
+newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a)
+newAcSparse typ prj idx val = case (typ, prj) of
+ (STNil, SAPHere) -> return ()
+ (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
+ (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val
+ (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val
+ (STArr _ t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val
+ (STScal sty, SAPHere) -> case sty of
+ STI32 -> return ()
+ STI64 -> return ()
+ STF32 -> newIORef val
+ STF64 -> newIORef val
+ STBool -> return ()
-newAcArray :: SNat n -> STy t -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> IO (Array n (RepAcSparse t))
-newAcArray n t = onehotArray t (newAcSparse t) (newAcZero t) n
+ (STPair t1 t2, SAPFst prj') ->
+ newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2
+ (STPair t1 t2, SAPSnd prj') ->
+ newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val
-onehotArray :: Monad m
- => STy t
- -> (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v) -- ^ the "one"
- -> m v -- ^ generate a zero value for elsewhere
- -> SNat n -> SNat i -> Rep (AcIdx (TArr n t) i) -> Rep (AcVal (TArr n t) i) -> m (Array n v)
-onehotArray _ mkone _ _ SZ _ val =
- traverse (mkone SZ ()) val
-onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do
- let sh = unTupRepIdx ShNil ShCons dim (fst val)
- go mkone dep dim idx (snd val) $ \arr position ->
- arrayGenerateM sh (\i -> case uninvert <$> piindexMatch position (invert i) of
- Just i' -> return $ arr `arrayIndex` i'
- Nothing -> mkzero)
- where
- go :: Monad m
- => (forall n'. SNat n' -> Rep (AcIdx t n') -> Rep (AcVal t n') -> m v)
- -> SNat i -> SNat n -> Rep (AcIdx (TArr n t) i) -> Rep (AcValArr n t i)
- -> (forall n'. Array n' v -> PartialInvIndex n n' -> m r) -> m r
- go mk SZ _ () val' k = arrayMapM (mk SZ ()) val' >>= \arr -> k arr PIIxEnd
- go mk (SS dep') SZ idx' val' k = mk dep' idx' val' >>= \arr -> k (arrayUnit arr) PIIxEnd
- go mk (SS dep') (SS dim') (i, idx') val' k =
- go mk dep' dim' idx' val' $ \arr pish ->
- k arr (PIIxCons (fromIntegral @Int64 @Int i) pish)
-
-newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t)
-newAcDense typ SZ () val = case typ of
- STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
- STEither t1 t2 -> case val of
- Left x -> Left <$> newAcSparse t1 SZ () x
- Right y -> Right <$> newAcSparse t2 SZ () y
- _ -> error "newAcDense: invalid dense type"
-newAcDense typ (SS dep) idx val = case typ of
- STPair t1 t2 ->
- case (idx, val) of
- (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
- (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
- _ -> error "Index/value mismatch in newAc pair"
- STEither t1 t2 ->
- case (idx, val) of
- (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val'
- (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val'
- _ -> error "Index/value mismatch in newAc either"
- _ -> error "newAcDense: invalid dense type"
+ (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val
+ (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val
+
+ (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val
+
+ (STArr n t, SAPArrIdx prj' _) -> newIORef =<< newAcArray n t prj' idx val
+
+ (STAccum{}, _) -> error "Accumulators not allowed in source program"
-readAcSparse :: STy t -> RepAcSparse t -> IO (Rep t)
+newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a))
+newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx
+
+onehotArray :: Monad m
+ => (Rep (AcIdx p a) -> m v) -- ^ the "one"
+ -> m v -- ^ the "zero"
+ -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v)
+onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) =
+ let arrindex = unTupRepIdx IxNil IxCons n arrindex'
+ arrsh = unTupRepIdx ShNil ShCons n arrsh'
+ in arrayGenerateM arrsh (\i -> if i == arrindex then mkone idx else mkzero)
+
+-- newAcDense :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAcDense (D2 a))
+-- newAcDense typ SZ () val = case typ of
+-- STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val)
+-- STEither t1 t2 -> case val of
+-- Left x -> Left <$> newAcSparse t1 SZ () x
+-- Right y -> Right <$> newAcSparse t2 SZ () y
+-- _ -> error "newAcDense: invalid dense type"
+-- newAcDense typ (SS dep) idx val = case typ of
+-- STPair t1 t2 ->
+-- case (idx, val) of
+-- (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2
+-- (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val'
+-- _ -> error "Index/value mismatch in newAc pair"
+-- STEither t1 t2 ->
+-- case (idx, val) of
+-- (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val'
+-- (Right idx', Right val') -> Right <$> newAcSparse t2 dep idx' val'
+-- _ -> error "Index/value mismatch in newAc either"
+-- _ -> error "newAcDense: invalid dense type"
+
+readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t))
readAcSparse typ val = case typ of
STNil -> return ()
- STPair t1 t2 -> do
- (a, b) <- readIORef val
- (,) <$> readAcSparse t1 a <*> readAcSparse t2 b
- STMaybe t -> traverse (readAcDense t) =<< readIORef val
+ STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
+ STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val
+ STMaybe t -> traverse (readAcSparse t) =<< readIORef val
STArr _ t -> traverse (readAcSparse t) =<< readIORef val
- STScal{} -> readIORef val
+ STScal sty -> case sty of
+ STI32 -> return ()
+ STI64 -> return ()
+ STF32 -> readIORef val
+ STF64 -> readIORef val
+ STBool -> return ()
STAccum{} -> error "Nested accumulators"
- STEither{} -> error "Bare Either in accumulator"
-readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
-readAcDense typ val = case typ of
- STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
- STEither t1 t2 -> case val of
- Left x -> Left <$> readAcSparse t1 x
- Right y -> Right <$> readAcSparse t2 y
- _ -> error "readAcDense: invalid dense type"
+-- readAcDense :: STy t -> RepAcDense t -> IO (Rep t)
+-- readAcDense typ val = case typ of
+-- STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val)
+-- STEither t1 t2 -> case val of
+-- Left x -> Left <$> readAcSparse t1 x
+-- Right y -> Right <$> readAcSparse t2 y
+-- _ -> error "readAcDense: invalid dense type"
+
+accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s ()
+accumAddSparse typ prj ref idx val = case (typ, prj) of
+ (STNil, SAPHere) -> return ()
-accumAddSparse :: STy t -> SNat i -> RepAcSparse t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
+ (STPair t1 t2, SAPHere) ->
+ case val of
+ Nothing -> return ()
+ Just (val1, val2) ->
+ AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1
+ <*> newAcSparse t2 SAPHere () val2)
+ (\(ac1, ac2) -> do unAcM $ accumAddSparse t1 SAPHere ac1 () val1
+ unAcM $ accumAddSparse t2 SAPHere ac2 () val2)
+ (STPair t1 t2, SAPFst prj') ->
+ AcM $ realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2)
+ (\(ac1, _) -> do unAcM $ accumAddSparse t1 prj' ac1 idx val)
+ (STPair t1 t2, SAPSnd prj') ->
+ AcM $ realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val)
+ (\(_, ac2) -> do unAcM $ accumAddSparse t2 prj' ac2 idx val)
+
+ (STEither t1 t2, SAPHere) -> _ ref val
+ (STEither t1 _, SAPLeft prj') -> _ ref idx val
+ (STEither _ t2, SAPRight prj') -> _ ref idx val
+
+ (STMaybe t1, SAPHere) -> _ ref val
+ (STMaybe t1, SAPJust prj') -> _ ref idx val
+
+ (STArr _ t1, SAPHere) -> _ ref val
+ (STArr n t, SAPArrIdx prj' _) -> _ ref idx val
+
+ (STScal sty, SAPHere) -> AcM $ case sty of
+ STI32 -> return ()
+ STI64 -> return ()
+ STF32 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STF64 -> atomicModifyIORef' ref (\x -> (x + val, ()))
+ STBool -> return ()
+
+ (STAccum{}, _) -> error "Accumulators not allowed in source program"
+
+realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> IO ()) -> IO ()
+realiseMaybeSparse ref makeval modifyval =
+ -- Try modifying what's already in ref. The 'join' makes the snd
+ -- of the function's return value a _continuation_ that is run after
+ -- the critical section ends.
+ join $ atomicModifyIORef' ref $ \ac -> case ac of
+ -- Oops, ref's contents was still sparse. Have to initialise
+ -- it first, then try again.
+ Nothing -> (ac, do val <- makeval
+ join $ atomicModifyIORef' ref $ \ac' -> case ac' of
+ Nothing -> (Just val, return ())
+ Just val' -> (ac', modifyval val'))
+ -- Yep, ref already had a value in there, so we can just add
+ -- val' to it recursively.
+ Just val -> (ac, modifyval val)
+
+{-
accumAddSparse typ SZ ref () val = case typ of
STNil -> return ()
STPair t1 t2 -> AcM $ do
@@ -532,6 +574,7 @@ accumAddDense typ (SS dep) ref idx val = case typ of
(Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val')
_ -> error "Mismatched Either in accumAddDense either"
_ -> error "accumAddDense: invalid dense type"
+-}
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index ac06915..f84f4e7 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -21,28 +21,25 @@ type family Rep t where
Rep (TMaybe t) = Maybe (Rep t)
Rep (TArr n t) = Array n (Rep t)
Rep (TScal sty) = ScalRep sty
- Rep (TAccum t) = RepAcSparse t
+ Rep (TAccum t) = RepAc t
--- Mutable, and has a zero. The zero may not be O(1), but RepAcSparse (D2 t) will have an O(1) zero.
-type family RepAcSparse t where
- RepAcSparse TNil = ()
- RepAcSparse (TPair a b) = IORef (RepAcSparse a, RepAcSparse b)
- RepAcSparse (TEither a b) = TypeError (Text "Non-sparse coproduct is not a monoid")
- RepAcSparse (TMaybe t) = IORef (Maybe (RepAcDense t)) -- allow the value to be dense, because the Maybe's zero can be used for the contents
+-- Mutable, represents D2 of t. Has an O(1) zero.
+type family RepAc t where
+ RepAc TNil = ()
+ RepAc (TPair a b) = IORef (Maybe (RepAc a, RepAc b))
+ RepAc (TEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b)))
+ RepAc (TMaybe t) = IORef (Maybe (RepAc t))
-- TODO: an empty array is invalid for a zero-dimensional array, so zero-dimensional arrays don't actually have an O(1) zero.
- RepAcSparse (TArr n t) = IORef (Array n (RepAcSparse t)) -- empty array is zero
- RepAcSparse (TScal sty) = IORef (ScalRep sty)
- RepAcSparse (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators")
+ RepAc (TArr n t) = IORef (Array n (RepAc t)) -- empty array is zero
+ RepAc (TScal sty) = RepAcScal sty
+ RepAc (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators")
--- Immutable, and does not necessarily have a zero.
-type family RepAcDense t where
- RepAcDense TNil = ()
- RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b)
- RepAcDense (TEither a b) = Either (RepAcSparse a) (RepAcSparse b)
- -- RepAcDense (TMaybe t) = RepAcSparse (TMaybe t) -- ^ This can be optimised to TMaybe (RepAcSparse t), but that makes accumAddDense very hard to write. And in any case, we don't need it because D2 will not produce Maybe of Maybe.
- -- RepAcDense (TArr n t) = Array n (RepAcSparse t)
- -- RepAcDense (TScal sty) = ScalRep sty
- -- RepAcDense (TAccum t) = TypeError (Text "RepAcDense: Nested accumulators")
+type family RepAcScal t where
+ RepAcScal TI32 = ()
+ RepAcScal TI64 = ()
+ RepAcScal TF32 = IORef Float
+ RepAcScal TF64 = IORef Double
+ RepAcScal TBool = ()
newtype Value t = Value { unValue :: Rep t }
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 2177789..0aa7a66 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -14,6 +14,7 @@ module Simplify (
import Data.Function (fix)
import Data.Monoid (Any(..))
+import Data.Type.Equality (testEquality)
import AST
import AST.Count
@@ -105,10 +106,10 @@ simplify' = \case
EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
-- projection down-commuting
- EFst _ (ECase _ e1 e2@EPair{} e3@EPair{}) ->
+ EFst _ (ECase _ e1 e2 e3) ->
acted $ simplify' $
ECase ext e1 (EFst ext e2) (EFst ext e3)
- ESnd _ (ECase _ e1 e2@EPair{} e3@EPair{}) ->
+ ESnd _ (ECase _ e1 e2 e3) ->
acted $ simplify' $
ECase ext e1 (ESnd ext e2) (ESnd ext e3)
@@ -118,16 +119,22 @@ simplify' = \case
-- TODO: constant folding for operations
- -- TODO: properly concatenate accum/onehot
- EAccum _ SZ _ (EOneHot _ _ i idx val) acc ->
- acted $ simplify' $
- EAccum ext i idx val acc
- EAccum _ _ _ (EZero _ _) _ -> (Any True, ENil ext)
+ -- monoid rules
+ EAccum _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val) acc
+ | Just Refl <- testEquality (acPrjTy prj1 t1) t2
+ -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
+ acted $ simplify' (EAccum ext t1 prj12 idx12 val acc)
+ EAccum _ _ _ _ (EZero _ _) _ -> (Any True, ENil ext)
EPlus _ _ (EZero _ _) e -> acted $ simplify' e
EPlus _ _ e (EZero _ _) -> acted $ simplify' e
- EOneHot _ _ SZ _ e -> acted $ simplify' e
-
- -- equations for plus
+ EOneHot _ t _ _ (EZero _ _) -> (Any True, EZero ext t)
+ EOneHot _ _ SAPHere _ e -> acted $ simplify' e
+ EOneHot _ t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)
+ | Just Refl <- testEquality (acPrjTy prj1 t1) t2
+ -> concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
+ acted $ simplify' (EOneHot ext t1 prj12 idx12 val)
+
+ -- type-specific equations for plus
EPlus _ STNil _ _ -> (Any True, ENil ext)
EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) ->
@@ -180,8 +187,8 @@ simplify' = \case
<*> (let ?accumInScope = False in simplify' b)
<*> (let ?accumInScope = False in simplify' c)
<*> simplify' e1 <*> simplify' e2
- EWith _ e1 e2 -> EWith ext <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
- EAccum _ i e1 e2 e3 -> EAccum ext i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
+ EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2)
+ EAccum _ t i e1 e2 e3 -> EAccum ext t i <$> simplify' e1 <*> simplify' e2 <*> simplify' e3
EZero _ t -> pure $ EZero ext t
EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b
EOneHot _ t i a b -> EOneHot ext t i <$> simplify' a <*> simplify' b
@@ -230,8 +237,8 @@ hasAdds = \case
EIdx _ a b -> hasAdds a || hasAdds b
EShape _ e -> hasAdds e
EOp _ _ e -> hasAdds e
- EWith _ a b -> hasAdds a || hasAdds b
- EAccum _ _ _ _ _ -> True
+ EWith _ _ a b -> hasAdds a || hasAdds b
+ EAccum _ _ _ _ _ _ -> True
EZero _ _ -> False
EPlus _ _ a b -> hasAdds a || hasAdds b
EOneHot _ _ _ a b -> hasAdds a || hasAdds b
@@ -249,3 +256,27 @@ checkAccumInScope = \case SNil -> False
check (STArr _ t) = check t
check (STScal _) = False
check STAccum{} = True
+
+concatOneHots :: STy a
+ -> SAcPrj p1 a b -> Ex env (AcIdx p1 a)
+ -> SAcPrj p2 b c -> Ex env (AcIdx p2 b)
+ -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r
+concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of
+ (_, SAPHere) -> k prj2 idx2
+
+ (STPair a _, SAPFst prj1') ->
+ concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12
+ (STPair _ b, SAPSnd prj1') ->
+ concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12
+
+ (STEither a _, SAPLeft prj1') ->
+ concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
+ (STEither _ b, SAPRight prj1') ->
+ concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
+
+ (STMaybe a, SAPJust prj1') ->
+ concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
+
+ (STArr n a, SAPArrIdx prj1' _) ->
+ concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
+ k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)