diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 47 |
1 files changed, 20 insertions, 27 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 087a26e..692bb96 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -309,9 +309,6 @@ type family D2s t where D2s TF64 = TScal TF64 D2s TBool = TNil -type family D2Ac t where - D2Ac (TArr n t) = TAccum n (D2 t) - type family D1E env where D1E '[] = '[] D1E (t : env) = D1 t : D1E env @@ -322,7 +319,7 @@ type family D2E env where type family D2AcE env where D2AcE '[] = '[] - D2AcE (t : env) = D2Ac t : D2AcE env + D2AcE (t : env) = TAccum (D2 t) : D2AcE env -- | Select only the types from the environment that have the specified storage type family Select env sto s where @@ -351,16 +348,13 @@ d2 (STScal t) = case t of STBool -> STNil d2 STAccum{} = error "Accumulators not allowed in input program" -d2ac :: STy t -> STy (D2Ac t) -d2ac (STArr n t) = STAccum n (d2 t) -d2ac _ = error "Only arrays may appear in the accumulator environment" - conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) -conv2Idx :: Descr env sto -> Idx env t -> Either (Idx (D2E (Select env sto "accum")) (D2 t)) - (Idx (Select env sto "merge") t) +conv2Idx :: Descr env sto -> Idx env t + -> Either (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) + (Idx (Select env sto "merge") t) conv2Idx (DPush _ (_, SAccum)) IZ = Left IZ conv2Idx (DPush _ (_, SMerge)) IZ = Right IZ conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i) @@ -371,7 +365,7 @@ zero :: STy t -> Ex env (D2 t) zero STNil = ENil ext zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext) zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) -zero (STArr n t) = EBuild ext (vecGenerate n (\_ -> EConst ext STI64 0)) (zero t) +zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) zero (STScal t) = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -464,11 +458,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = envpro prosub (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr))) + autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) (#pro :++: #d :++: #shb :++: #acc :++: #tl) .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum storepl))) + .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) @@ -489,7 +483,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = -- goal: | ARE EQUAL || -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) WCopy (wf shbinds) - .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + .> WPick @(TAccum (D2 (TArr arrn arrt))) @(D2 dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) -- "merge" values must be an array, so reject everything else. (TODO: generalise this) @@ -505,10 +499,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = -- STScal{} -> False -- STAccum{} -> error "An accumulator in merge storage?" -type family InvTup core env where - InvTup core '[] = core - InvTup core (t : ts) = InvTup (TPair core t) ts - makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators SNil e = e makeAccumulators (STArr n t `SCons` envpro) e = @@ -753,7 +743,7 @@ d2e (SCons t ts) = SCons (d2 t) (d2e ts) d2ace :: SList STy env -> SList STy (D2AcE env) d2ace SNil = SNil -d2ace (SCons t ts) = SCons (d2ac t) (d2ace ts) +d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts) freezeRet :: Descr env sto -> Ret env sto t @@ -775,11 +765,11 @@ drev :: forall env sto t. drev des = \case EVar _ t i -> case conv2Idx des i of - Left _ -> + Left accI -> Ret BTop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) - (ENil ext) + (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI))) Right tupI -> Ret BTop @@ -1075,22 +1065,25 @@ drev des = \case -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil - -> - Ret binds - (EIdx1 ext e1 ei1) + , STArr (SS n) eltty <- typeOf e -> + Ret (binds `BPush` (tTup (sreplicate (SS n) tIx), EShape ext e1)) + (weakenExpr WSink (EIdx1 ext e1 ei1)) sub - (_ e2) + (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ + weakenExpr (WCopy (WSink .> WSink)) e2) -- These should be the next to be implemented, I think EFold1{} -> err_unsupported "EFold1" EShape{} -> err_unsupported "EShape" - EReplicate{} -> err_unsupported "EReplicate" + -- EReplicate{} -> err_unsupported "EReplicate" EIdx{} -> err_unsupported "EIdx" EBuild{} -> err_unsupported "EBuild" EWith{} -> err_accum - EAccum1{} -> err_accum + EAccum{} -> err_accum where err_accum = error "Accumulator operations unsupported in the source program" |