diff options
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 136 |
1 files changed, 42 insertions, 94 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 7747d46..1ab2da0 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -37,6 +37,7 @@ import AST import AST.Count import AST.Env import AST.Weaken.Auto +import CHAD.Types import Data import Lemmas @@ -288,66 +289,12 @@ vectorise1Binds env n (bs `BPush` (t, e)) = (vectoriseExpr SNil (bindingsBinds bs) env e) in bs' `BPush` (STArr (SS SZ) t, e') -type family D1 t where - D1 TNil = TNil - D1 (TPair a b) = TPair (D1 a) (D1 b) - D1 (TEither a b) = TEither (D1 a) (D1 b) - D1 (TArr n t) = TArr n (D1 t) - D1 (TScal t) = TScal t - -type family D2 t where - D2 TNil = TNil - D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b)) - D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b)) - D2 (TArr n t) = TArr n (D2 t) - D2 (TScal t) = D2s t - -type family D2s t where - D2s TI32 = TNil - D2s TI64 = TNil - D2s TF32 = TScal TF32 - D2s TF64 = TScal TF64 - D2s TBool = TNil - -type family D1E env where - D1E '[] = '[] - D1E (t : env) = D1 t : D1E env - -type family D2E env where - D2E '[] = '[] - D2E (t : env) = D2 t : D2E env - -type family D2AcE env where - D2AcE '[] = '[] - D2AcE (t : env) = TAccum (D2 t) : D2AcE env - -- | Select only the types from the environment that have the specified storage type family Select env sto s where Select '[] '[] _ = '[] Select (t : ts) (s : sto) s = t : Select ts sto s Select (_ : ts) (_ : sto) s = Select ts sto s -d1 :: STy t -> STy (D1 t) -d1 STNil = STNil -d1 (STPair a b) = STPair (d1 a) (d1 b) -d1 (STEither a b) = STEither (d1 a) (d1 b) -d1 (STArr n t) = STArr n (d1 t) -d1 (STScal t) = STScal t -d1 STAccum{} = error "Accumulators not allowed in input program" - -d2 :: STy t -> STy (D2 t) -d2 STNil = STNil -d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b)) -d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b)) -d2 (STArr n t) = STArr n (d2 t) -d2 (STScal t) = case t of - STI32 -> STNil - STI64 -> STNil - STF32 -> STScal STF32 - STF64 -> STScal STF64 - STBool -> STNil -d2 STAccum{} = error "Accumulators not allowed in input program" - conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) @@ -362,48 +309,49 @@ conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i) conv2Idx DTop i = case i of {} zero :: STy t -> Ex env (D2 t) -zero STNil = ENil ext -zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext) -zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) -zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) -zero (STScal t) = case t of - STI32 -> ENil ext - STI64 -> ENil ext - STF32 -> EConst ext STF32 0.0 - STF64 -> EConst ext STF64 0.0 - STBool -> ENil ext -zero STAccum{} = error "Accumulators not allowed in input program" +zero = EZero +-- TODO: this original definition needs to be used as the post-processing after +-- simplification, to eliminate the monoid operations from the AST +-- zero STNil = ENil ext +-- zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext) +-- zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) +-- zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) +-- zero (STScal t) = case t of +-- STI32 -> ENil ext +-- STI64 -> ENil ext +-- STF32 -> EConst ext STF32 0.0 +-- STF64 -> EConst ext STF64 0.0 +-- STBool -> ENil ext +-- zero STAccum{} = error "Accumulators not allowed in input program" plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) -plus STNil _ _ = ENil ext -plus (STPair t1 t2) a b = - let t = STPair (d2 t1) (d2 t2) - in plusSparse t a b $ - EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) - (EFst ext (EVar ext t IZ))) - (plus t2 (ESnd ext (EVar ext t (IS IZ))) - (ESnd ext (EVar ext t IZ))) -plus (STEither t1 t2) a b = - let t = STEither (d2 t1) (d2 t2) - in plusSparse t a b $ - ECase ext (EVar ext t (IS IZ)) - (ECase ext (EVar ext t (IS IZ)) - (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) - (EError t "plus l+r")) - (ECase ext (EVar ext t (IS IZ)) - (EError t "plus r+l") - (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) -plus STArr{} _ _ = error "TODO plus on arrays" - -- 'zero' creates an empty array; this should be a new primitive that - -- (operationally) intelligently memcpy's the non-overlapping part and does - -- a parallel add on the overlapping part. -plus (STScal t) a b = case t of - STI32 -> ENil ext - STI64 -> ENil ext - STF32 -> EOp ext (OAdd STF32) (EPair ext a b) - STF64 -> EOp ext (OAdd STF64) (EPair ext a b) - STBool -> ENil ext -plus STAccum{} _ _ = error "Accumulators not allowed in input program" +plus = EPlus +-- plus STNil _ _ = ENil ext +-- plus (STPair t1 t2) a b = +-- let t = STPair (d2 t1) (d2 t2) +-- in plusSparse t a b $ +-- EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) +-- (EFst ext (EVar ext t IZ))) +-- (plus t2 (ESnd ext (EVar ext t (IS IZ))) +-- (ESnd ext (EVar ext t IZ))) +-- plus (STEither t1 t2) a b = +-- let t = STEither (d2 t1) (d2 t2) +-- in plusSparse t a b $ +-- ECase ext (EVar ext t (IS IZ)) +-- (ECase ext (EVar ext t (IS IZ)) +-- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) +-- (EError t "plus l+r")) +-- (ECase ext (EVar ext t (IS IZ)) +-- (EError t "plus r+l") +-- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) +-- plus STArr{} _ _ = error "TODO plus on arrays" +-- plus (STScal t) a b = case t of +-- STI32 -> ENil ext +-- STI64 -> ENil ext +-- STF32 -> EOp ext (OAdd STF32) (EPair ext a b) +-- STF64 -> EOp ext (OAdd STF64) (EPair ext a b) +-- STBool -> ENil ext +-- plus STAccum{} _ _ = error "Accumulators not allowed in input program" plusSparse :: STy a -> Ex env (TEither TNil a) -> Ex env (TEither TNil a) |