summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs47
1 files changed, 20 insertions, 27 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 087a26e..692bb96 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -309,9 +309,6 @@ type family D2s t where
D2s TF64 = TScal TF64
D2s TBool = TNil
-type family D2Ac t where
- D2Ac (TArr n t) = TAccum n (D2 t)
-
type family D1E env where
D1E '[] = '[]
D1E (t : env) = D1 t : D1E env
@@ -322,7 +319,7 @@ type family D2E env where
type family D2AcE env where
D2AcE '[] = '[]
- D2AcE (t : env) = D2Ac t : D2AcE env
+ D2AcE (t : env) = TAccum (D2 t) : D2AcE env
-- | Select only the types from the environment that have the specified storage
type family Select env sto s where
@@ -351,16 +348,13 @@ d2 (STScal t) = case t of
STBool -> STNil
d2 STAccum{} = error "Accumulators not allowed in input program"
-d2ac :: STy t -> STy (D2Ac t)
-d2ac (STArr n t) = STAccum n (d2 t)
-d2ac _ = error "Only arrays may appear in the accumulator environment"
-
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)
-conv2Idx :: Descr env sto -> Idx env t -> Either (Idx (D2E (Select env sto "accum")) (D2 t))
- (Idx (Select env sto "merge") t)
+conv2Idx :: Descr env sto -> Idx env t
+ -> Either (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
+ (Idx (Select env sto "merge") t)
conv2Idx (DPush _ (_, SAccum)) IZ = Left IZ
conv2Idx (DPush _ (_, SMerge)) IZ = Right IZ
conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i)
@@ -371,7 +365,7 @@ zero :: STy t -> Ex env (D2 t)
zero STNil = ENil ext
zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext)
zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext)
-zero (STArr n t) = EBuild ext (vecGenerate n (\_ -> EConst ext STI64 0)) (zero t)
+zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t)
zero (STScal t) = case t of
STI32 -> ENil ext
STI64 -> ENil ext
@@ -464,11 +458,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
envpro
prosub
(\shbinds ->
- autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr)))
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
(#acc :++: (#pro :++: #d :++: #shb :++: #tl))
(#pro :++: #d :++: #shb :++: #acc :++: #tl)
.> WCopy (wf shbinds)
- .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum storepl)))
+ .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
(#d :++: #shb :++: #acc :++: #tl)
(#acc :++: (#d :++: #shb :++: #tl)))
@@ -489,7 +483,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
-- goal: | ARE EQUAL ||
-- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
WCopy (wf shbinds)
- .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
+ .> WPick @(TAccum (D2 (TArr arrn arrt))) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
(WId @(D2AcE (Select env1 stoRepl "accum"))))
-- "merge" values must be an array, so reject everything else. (TODO: generalise this)
@@ -505,10 +499,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
-- STScal{} -> False
-- STAccum{} -> error "An accumulator in merge storage?"
-type family InvTup core env where
- InvTup core '[] = core
- InvTup core (t : ts) = InvTup (TPair core t) ts
-
makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
makeAccumulators SNil e = e
makeAccumulators (STArr n t `SCons` envpro) e =
@@ -753,7 +743,7 @@ d2e (SCons t ts) = SCons (d2 t) (d2e ts)
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace SNil = SNil
-d2ace (SCons t ts) = SCons (d2ac t) (d2ace ts)
+d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts)
freezeRet :: Descr env sto
-> Ret env sto t
@@ -775,11 +765,11 @@ drev :: forall env sto t.
drev des = \case
EVar _ t i ->
case conv2Idx des i of
- Left _ ->
+ Left accI ->
Ret BTop
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (select SMerge des))
- (ENil ext)
+ (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI)))
Right tupI ->
Ret BTop
@@ -1075,22 +1065,25 @@ drev des = \case
-- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
| Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
<- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
- ->
- Ret binds
- (EIdx1 ext e1 ei1)
+ , STArr (SS n) eltty <- typeOf e ->
+ Ret (binds `BPush` (tTup (sreplicate (SS n) tIx), EShape ext e1))
+ (weakenExpr WSink (EIdx1 ext e1 ei1))
sub
- (_ e2)
+ (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (d2 eltty)) (IS IZ))) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
-- These should be the next to be implemented, I think
EFold1{} -> err_unsupported "EFold1"
EShape{} -> err_unsupported "EShape"
- EReplicate{} -> err_unsupported "EReplicate"
+ -- EReplicate{} -> err_unsupported "EReplicate"
EIdx{} -> err_unsupported "EIdx"
EBuild{} -> err_unsupported "EBuild"
EWith{} -> err_accum
- EAccum1{} -> err_accum
+ EAccum{} -> err_accum
where
err_accum = error "Accumulator operations unsupported in the source program"