summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs136
1 files changed, 42 insertions, 94 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 7747d46..1ab2da0 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -37,6 +37,7 @@ import AST
import AST.Count
import AST.Env
import AST.Weaken.Auto
+import CHAD.Types
import Data
import Lemmas
@@ -288,66 +289,12 @@ vectorise1Binds env n (bs `BPush` (t, e)) =
(vectoriseExpr SNil (bindingsBinds bs) env e)
in bs' `BPush` (STArr (SS SZ) t, e')
-type family D1 t where
- D1 TNil = TNil
- D1 (TPair a b) = TPair (D1 a) (D1 b)
- D1 (TEither a b) = TEither (D1 a) (D1 b)
- D1 (TArr n t) = TArr n (D1 t)
- D1 (TScal t) = TScal t
-
-type family D2 t where
- D2 TNil = TNil
- D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b))
- D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b))
- D2 (TArr n t) = TArr n (D2 t)
- D2 (TScal t) = D2s t
-
-type family D2s t where
- D2s TI32 = TNil
- D2s TI64 = TNil
- D2s TF32 = TScal TF32
- D2s TF64 = TScal TF64
- D2s TBool = TNil
-
-type family D1E env where
- D1E '[] = '[]
- D1E (t : env) = D1 t : D1E env
-
-type family D2E env where
- D2E '[] = '[]
- D2E (t : env) = D2 t : D2E env
-
-type family D2AcE env where
- D2AcE '[] = '[]
- D2AcE (t : env) = TAccum (D2 t) : D2AcE env
-
-- | Select only the types from the environment that have the specified storage
type family Select env sto s where
Select '[] '[] _ = '[]
Select (t : ts) (s : sto) s = t : Select ts sto s
Select (_ : ts) (_ : sto) s = Select ts sto s
-d1 :: STy t -> STy (D1 t)
-d1 STNil = STNil
-d1 (STPair a b) = STPair (d1 a) (d1 b)
-d1 (STEither a b) = STEither (d1 a) (d1 b)
-d1 (STArr n t) = STArr n (d1 t)
-d1 (STScal t) = STScal t
-d1 STAccum{} = error "Accumulators not allowed in input program"
-
-d2 :: STy t -> STy (D2 t)
-d2 STNil = STNil
-d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b))
-d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b))
-d2 (STArr n t) = STArr n (d2 t)
-d2 (STScal t) = case t of
- STI32 -> STNil
- STI64 -> STNil
- STF32 -> STScal STF32
- STF64 -> STScal STF64
- STBool -> STNil
-d2 STAccum{} = error "Accumulators not allowed in input program"
-
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)
@@ -362,48 +309,49 @@ conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i)
conv2Idx DTop i = case i of {}
zero :: STy t -> Ex env (D2 t)
-zero STNil = ENil ext
-zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext)
-zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext)
-zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t)
-zero (STScal t) = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
- STF32 -> EConst ext STF32 0.0
- STF64 -> EConst ext STF64 0.0
- STBool -> ENil ext
-zero STAccum{} = error "Accumulators not allowed in input program"
+zero = EZero
+-- TODO: this original definition needs to be used as the post-processing after
+-- simplification, to eliminate the monoid operations from the AST
+-- zero STNil = ENil ext
+-- zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext)
+-- zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext)
+-- zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t)
+-- zero (STScal t) = case t of
+-- STI32 -> ENil ext
+-- STI64 -> ENil ext
+-- STF32 -> EConst ext STF32 0.0
+-- STF64 -> EConst ext STF64 0.0
+-- STBool -> ENil ext
+-- zero STAccum{} = error "Accumulators not allowed in input program"
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
-plus STNil _ _ = ENil ext
-plus (STPair t1 t2) a b =
- let t = STPair (d2 t1) (d2 t2)
- in plusSparse t a b $
- EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
- (EFst ext (EVar ext t IZ)))
- (plus t2 (ESnd ext (EVar ext t (IS IZ)))
- (ESnd ext (EVar ext t IZ)))
-plus (STEither t1 t2) a b =
- let t = STEither (d2 t1) (d2 t2)
- in plusSparse t a b $
- ECase ext (EVar ext t (IS IZ))
- (ECase ext (EVar ext t (IS IZ))
- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
- (EError t "plus l+r"))
- (ECase ext (EVar ext t (IS IZ))
- (EError t "plus r+l")
- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
-plus STArr{} _ _ = error "TODO plus on arrays"
- -- 'zero' creates an empty array; this should be a new primitive that
- -- (operationally) intelligently memcpy's the non-overlapping part and does
- -- a parallel add on the overlapping part.
-plus (STScal t) a b = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
- STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
- STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
- STBool -> ENil ext
-plus STAccum{} _ _ = error "Accumulators not allowed in input program"
+plus = EPlus
+-- plus STNil _ _ = ENil ext
+-- plus (STPair t1 t2) a b =
+-- let t = STPair (d2 t1) (d2 t2)
+-- in plusSparse t a b $
+-- EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
+-- (EFst ext (EVar ext t IZ)))
+-- (plus t2 (ESnd ext (EVar ext t (IS IZ)))
+-- (ESnd ext (EVar ext t IZ)))
+-- plus (STEither t1 t2) a b =
+-- let t = STEither (d2 t1) (d2 t2)
+-- in plusSparse t a b $
+-- ECase ext (EVar ext t (IS IZ))
+-- (ECase ext (EVar ext t (IS IZ))
+-- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
+-- (EError t "plus l+r"))
+-- (ECase ext (EVar ext t (IS IZ))
+-- (EError t "plus r+l")
+-- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
+-- plus STArr{} _ _ = error "TODO plus on arrays"
+-- plus (STScal t) a b = case t of
+-- STI32 -> ENil ext
+-- STI64 -> ENil ext
+-- STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
+-- STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
+-- STBool -> ENil ext
+-- plus STAccum{} _ _ = error "Accumulators not allowed in input program"
plusSparse :: STy a
-> Ex env (TEither TNil a) -> Ex env (TEither TNil a)