summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal4
-rw-r--r--src/AST.hs109
-rw-r--r--src/AST/Count.hs3
-rw-r--r--src/AST/Pretty.hs7
-rw-r--r--src/AST/Types.hs95
-rw-r--r--src/Array.hs15
-rw-r--r--src/CHAD.hs136
-rw-r--r--src/CHAD/Types.hs65
-rw-r--r--src/Interpreter.hs142
-rw-r--r--src/Interpreter/Accum.hs26
-rw-r--r--src/Interpreter/AccumOld.hs366
-rw-r--r--src/Interpreter/Rep.hs20
-rw-r--r--src/Simplify.hs6
13 files changed, 756 insertions, 238 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 6314acd..1b95c66 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -15,14 +15,16 @@ library
AST.Count
AST.Env
AST.Pretty
+ AST.Types
AST.Weaken
AST.Weaken.Auto
CHAD
+ CHAD.Types
-- Compile
Data
Example
Interpreter
- Interpreter.Accum
+ -- Interpreter.AccumOld
Interpreter.Rep
Language
Language.AST
diff --git a/src/AST.hs b/src/AST.hs
index 2132bc6..ed2039b 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -13,92 +13,19 @@
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-module AST (module AST, module AST.Weaken) where
+module AST (module AST, module AST.Types, module AST.Weaken) where
import Data.Functor.Const
import Data.Kind (Type)
-import Data.Int
-import Data.Type.Equality
import Array
import AST.Env
+import AST.Types
import AST.Weaken
+import CHAD.Types
import Data
-data Ty
- = TNil
- | TPair Ty Ty
- | TEither Ty Ty
- | TArr Nat Ty -- ^ rank, element type
- | TScal ScalTy
- | TAccum Ty
- deriving (Show, Eq, Ord)
-
-data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
- deriving (Show, Eq, Ord)
-
-type STy :: Ty -> Type
-data STy t where
- STNil :: STy TNil
- STPair :: STy a -> STy b -> STy (TPair a b)
- STEither :: STy a -> STy b -> STy (TEither a b)
- STArr :: SNat n -> STy t -> STy (TArr n t)
- STScal :: SScalTy t -> STy (TScal t)
- STAccum :: STy t -> STy (TAccum t)
-deriving instance Show (STy t)
-
-instance TestEquality STy where
- testEquality STNil STNil = Just Refl
- testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
- testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
- testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
- testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl
- testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl
- testEquality _ _ = Nothing
-
-data SScalTy t where
- STI32 :: SScalTy TI32
- STI64 :: SScalTy TI64
- STF32 :: SScalTy TF32
- STF64 :: SScalTy TF64
- STBool :: SScalTy TBool
-deriving instance Show (SScalTy t)
-
-instance TestEquality SScalTy where
- testEquality STI32 STI32 = Just Refl
- testEquality STI64 STI64 = Just Refl
- testEquality STF32 STF32 = Just Refl
- testEquality STF64 STF64 = Just Refl
- testEquality STBool STBool = Just Refl
- testEquality _ _ = Nothing
-
-scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
-scalRepIsShow STI32 = Dict
-scalRepIsShow STI64 = Dict
-scalRepIsShow STF32 = Dict
-scalRepIsShow STF64 = Dict
-scalRepIsShow STBool = Dict
-
-type TIx = TScal TI64
-
-tIx :: STy TIx
-tIx = STScal STI64
-
-type family ScalRep t where
- ScalRep TI32 = Int32
- ScalRep TI64 = Int64
- ScalRep TF32 = Float
- ScalRep TF64 = Double
- ScalRep TBool = Bool
-
-type family ScalIsNumeric t where
- ScalIsNumeric TI32 = True
- ScalIsNumeric TI64 = True
- ScalIsNumeric TF32 = True
- ScalIsNumeric TF64 = True
- ScalIsNumeric TBool = False
-
-- | This index is flipped around from the usual direction: the smallest index
-- is at the _heart_ of the nesting, not at the outside. The outermost layer
-- indexes into the _outer_ dimension of the type @t@. This makes indices into
@@ -107,6 +34,7 @@ type family AcIdx t i where
AcIdx t Z = TNil
AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
+ AcIdx (TMaybe t) (S i) = AcIdx t i
AcIdx (TArr Z t) (S i) = AcIdx t i
AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i)
@@ -114,12 +42,20 @@ type family AcVal t i where
AcVal t Z = t
AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i)
AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i)
+ AcVal (TMaybe t) (S i) = AcVal t i
AcVal (TArr Z t) (S i) = AcVal t i
AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i
-- General assumption: head of the list (whatever way it is associated) is the
-- inner variable / inner array dimension. In pretty printing, the inner
-- variable / inner dimension is printed on the _right_.
+--
+-- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the
+-- type transformation of CHAD. Indeed, these constructors are created _by_
+-- CHAD, and are intended to be eliminated after simplification, so that the
+-- input program as well as the output program do not contain these
+-- constructors.
+-- TODO: ensure this by a "stage" type parameter.
type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
data Expr x env t where
-- lambda calculus
@@ -134,6 +70,9 @@ data Expr x env t where
EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
+ ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
+ EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
+ EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b
-- array operations
EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
@@ -157,6 +96,10 @@ data Expr x env t where
EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil
-- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil
+ -- monoidal operations (to be desugared to regular operations after simplification)
+ EZero :: STy t -> Expr x env (D2 t)
+ EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t)
+
-- partiality
EError :: STy a -> String -> Expr x env a
deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
@@ -220,6 +163,9 @@ typeOf = \case
EInl _ t2 e -> STEither (typeOf e) t2
EInr _ t1 e -> STEither t1 (typeOf e)
ECase _ _ a _ -> typeOf a
+ ENothing _ t -> STMaybe t
+ EJust _ e -> STMaybe (typeOf e)
+ EMaybe _ e _ _ -> typeOf e
EConstArr _ n t _ -> STArr n (STScal t)
EBuild1 _ _ e -> STArr (SS SZ) (typeOf e)
@@ -239,6 +185,9 @@ typeOf = \case
EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
EAccum _ _ _ _ -> STNil
+ EZero t -> d2 t
+ EPlus t _ _ -> d2 t
+
EError t _ -> t
unSNat :: SNat n -> Nat
@@ -250,6 +199,7 @@ unSTy = \case
STNil -> TNil
STPair a b -> TPair (unSTy a) (unSTy b)
STEither a b -> TEither (unSTy a) (unSTy b)
+ STMaybe t -> TMaybe (unSTy t)
STArr n t -> TArr (unSNat n) (unSTy t)
STScal t -> TScal (unSScalTy t)
STAccum t -> TAccum (unSTy t)
@@ -288,6 +238,9 @@ subst' f w = \case
EInl x t e -> EInl x t (subst' f w e)
EInr x t e -> EInr x t (subst' f w e)
ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
+ ENothing x t -> ENothing x t
+ EJust x e -> EJust x (subst' f w e)
+ EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
EConstArr x n t a -> EConstArr x n t a
EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
@@ -303,6 +256,8 @@ subst' f w = \case
EOp x op e -> EOp x op (subst' f w e)
EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
+ EZero t -> EZero t
+ EPlus t a b -> EPlus t (subst' f w a) (subst' f w b)
EError t s -> EError t s
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
@@ -339,6 +294,7 @@ class KnownTy t where knownTy :: STy t
instance KnownTy TNil where knownTy = STNil
instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
+instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy
@@ -351,6 +307,7 @@ styKnown :: STy t -> Dict (KnownTy t)
styKnown STNil = Dict
styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+styKnown (STMaybe t) | Dict <- styKnown t = Dict
styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
styKnown (STAccum t) | Dict <- styKnown t = Dict
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 6a00e83..40a46f6 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -46,6 +46,7 @@ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
-- | This code is executed many times
scaleMany :: Occ -> Occ
+scaleMany (Occ l Zero) = Occ l Zero
scaleMany (Occ l _) = Occ l Many
occCount :: Idx env a -> Expr x env t -> Occ
@@ -124,6 +125,8 @@ occCountGeneral onehot unpush alter many = go WId
EOp _ _ e -> re e
EWith a b -> re a <> re1 b
EAccum _ a b e -> re a <> re b <> re e
+ EZero _ -> mempty
+ EPlus _ a b -> re a <> re b
EError{} -> mempty
where
re :: Monoid (r env') => Expr x env' t'' -> r env'
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 2ce883b..f5e681a 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -181,6 +181,13 @@ ppExpr' d val = \case
return $ showParen (d > 10) $
showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
+ EZero _ -> return $ showString "zero"
+
+ EPlus _ a b -> do
+ a' <- ppExpr' 11 val a
+ b' <- ppExpr' 11 val b
+ return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b'
+
EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
new file mode 100644
index 0000000..a3e5080
--- /dev/null
+++ b/src/AST/Types.hs
@@ -0,0 +1,95 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeFamilies #-}
+module AST.Types where
+
+import Data.Int (Int32, Int64)
+import Data.Kind (Type)
+import Data.Type.Equality
+
+import Data
+
+
+data Ty
+ = TNil
+ | TPair Ty Ty
+ | TEither Ty Ty
+ | TMaybe Ty
+ | TArr Nat Ty -- ^ rank, element type
+ | TScal ScalTy
+ | TAccum Ty
+ deriving (Show, Eq, Ord)
+
+data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
+ deriving (Show, Eq, Ord)
+
+type STy :: Ty -> Type
+data STy t where
+ STNil :: STy TNil
+ STPair :: STy a -> STy b -> STy (TPair a b)
+ STEither :: STy a -> STy b -> STy (TEither a b)
+ STMaybe :: STy a -> STy (TMaybe a)
+ STArr :: SNat n -> STy t -> STy (TArr n t)
+ STScal :: SScalTy t -> STy (TScal t)
+ STAccum :: STy t -> STy (TAccum t)
+deriving instance Show (STy t)
+
+instance TestEquality STy where
+ testEquality STNil STNil = Just Refl
+ testEquality STNil _ = Nothing
+ testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality STPair{} _ = Nothing
+ testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality STEither{} _ = Nothing
+ testEquality (STMaybe a) (STMaybe a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality STMaybe{} _ = Nothing
+ testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl
+ testEquality STArr{} _ = Nothing
+ testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality STScal{} _ = Nothing
+ testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl
+ testEquality STAccum{} _ = Nothing
+
+data SScalTy t where
+ STI32 :: SScalTy TI32
+ STI64 :: SScalTy TI64
+ STF32 :: SScalTy TF32
+ STF64 :: SScalTy TF64
+ STBool :: SScalTy TBool
+deriving instance Show (SScalTy t)
+
+instance TestEquality SScalTy where
+ testEquality STI32 STI32 = Just Refl
+ testEquality STI64 STI64 = Just Refl
+ testEquality STF32 STF32 = Just Refl
+ testEquality STF64 STF64 = Just Refl
+ testEquality STBool STBool = Just Refl
+ testEquality _ _ = Nothing
+
+scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
+scalRepIsShow STI32 = Dict
+scalRepIsShow STI64 = Dict
+scalRepIsShow STF32 = Dict
+scalRepIsShow STF64 = Dict
+scalRepIsShow STBool = Dict
+
+type TIx = TScal TI64
+
+tIx :: STy TIx
+tIx = STScal STI64
+
+type family ScalRep t where
+ ScalRep TI32 = Int32
+ ScalRep TI64 = Int64
+ ScalRep TF32 = Float
+ ScalRep TF64 = Double
+ ScalRep TBool = Bool
+
+type family ScalIsNumeric t where
+ ScalIsNumeric TI32 = True
+ ScalIsNumeric TI64 = True
+ ScalIsNumeric TF32 = True
+ ScalIsNumeric TF64 = True
+ ScalIsNumeric TBool = False
diff --git a/src/Array.hs b/src/Array.hs
index 9a770c4..0d585a9 100644
--- a/src/Array.hs
+++ b/src/Array.hs
@@ -17,11 +17,13 @@ data Shape n where
ShNil :: Shape Z
ShCons :: Shape n -> Int -> Shape (S n)
deriving instance Show (Shape n)
+deriving instance Eq (Shape n)
data Index n where
IxNil :: Index Z
IxCons :: Index n -> Int -> Index (S n)
deriving instance Show (Index n)
+deriving instance Eq (Index n)
shapeSize :: Shape n -> Int
shapeSize ShNil = 0
@@ -38,6 +40,10 @@ toLinearIndex :: Shape n -> Index n -> Int
toLinearIndex ShNil IxNil = 0
toLinearIndex (sh `ShCons` n) (idx `IxCons` i) = toLinearIndex sh idx * n + i
+emptyShape :: SNat n -> Shape n
+emptyShape SZ = ShNil
+emptyShape (SS m) = emptyShape m `ShCons` 0
+
-- | TODO: this Vector is a boxed vector, which is horrendously inefficient.
data Array (n :: Nat) t = Array (Shape n) (Vector t)
@@ -49,6 +55,9 @@ arrayShape (Array sh _) = sh
arraySize :: Array n t -> Int
arraySize (Array sh _) = shapeSize sh
+emptyArray :: SNat n -> Array n t
+emptyArray n = Array (emptyShape n) V.empty
+
arrayIndex :: Array n t -> Index n -> t
arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx)
@@ -58,6 +67,12 @@ arrayIndexLinear (Array _ v) i = v V.! i
arrayIndex1 :: Array (S n) t -> Int -> Array n t
arrayIndex1 (Array (sh `ShCons` _) v) i = let sz = shapeSize sh in Array sh (V.slice (sz * i) sz v)
+arrayGenerate :: Shape n -> (Index n -> t) -> Array n t
+arrayGenerate sh f = arrayGenerateLin sh (f . fromLinearIndex sh)
+
+arrayGenerateLin :: Shape n -> (Int -> t) -> Array n t
+arrayGenerateLin sh f = Array sh (V.generate (shapeSize sh) f)
+
arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t)
arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 7747d46..1ab2da0 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -37,6 +37,7 @@ import AST
import AST.Count
import AST.Env
import AST.Weaken.Auto
+import CHAD.Types
import Data
import Lemmas
@@ -288,66 +289,12 @@ vectorise1Binds env n (bs `BPush` (t, e)) =
(vectoriseExpr SNil (bindingsBinds bs) env e)
in bs' `BPush` (STArr (SS SZ) t, e')
-type family D1 t where
- D1 TNil = TNil
- D1 (TPair a b) = TPair (D1 a) (D1 b)
- D1 (TEither a b) = TEither (D1 a) (D1 b)
- D1 (TArr n t) = TArr n (D1 t)
- D1 (TScal t) = TScal t
-
-type family D2 t where
- D2 TNil = TNil
- D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b))
- D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b))
- D2 (TArr n t) = TArr n (D2 t)
- D2 (TScal t) = D2s t
-
-type family D2s t where
- D2s TI32 = TNil
- D2s TI64 = TNil
- D2s TF32 = TScal TF32
- D2s TF64 = TScal TF64
- D2s TBool = TNil
-
-type family D1E env where
- D1E '[] = '[]
- D1E (t : env) = D1 t : D1E env
-
-type family D2E env where
- D2E '[] = '[]
- D2E (t : env) = D2 t : D2E env
-
-type family D2AcE env where
- D2AcE '[] = '[]
- D2AcE (t : env) = TAccum (D2 t) : D2AcE env
-
-- | Select only the types from the environment that have the specified storage
type family Select env sto s where
Select '[] '[] _ = '[]
Select (t : ts) (s : sto) s = t : Select ts sto s
Select (_ : ts) (_ : sto) s = Select ts sto s
-d1 :: STy t -> STy (D1 t)
-d1 STNil = STNil
-d1 (STPair a b) = STPair (d1 a) (d1 b)
-d1 (STEither a b) = STEither (d1 a) (d1 b)
-d1 (STArr n t) = STArr n (d1 t)
-d1 (STScal t) = STScal t
-d1 STAccum{} = error "Accumulators not allowed in input program"
-
-d2 :: STy t -> STy (D2 t)
-d2 STNil = STNil
-d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b))
-d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b))
-d2 (STArr n t) = STArr n (d2 t)
-d2 (STScal t) = case t of
- STI32 -> STNil
- STI64 -> STNil
- STF32 -> STScal STF32
- STF64 -> STScal STF64
- STBool -> STNil
-d2 STAccum{} = error "Accumulators not allowed in input program"
-
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)
@@ -362,48 +309,49 @@ conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i)
conv2Idx DTop i = case i of {}
zero :: STy t -> Ex env (D2 t)
-zero STNil = ENil ext
-zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext)
-zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext)
-zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t)
-zero (STScal t) = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
- STF32 -> EConst ext STF32 0.0
- STF64 -> EConst ext STF64 0.0
- STBool -> ENil ext
-zero STAccum{} = error "Accumulators not allowed in input program"
+zero = EZero
+-- TODO: this original definition needs to be used as the post-processing after
+-- simplification, to eliminate the monoid operations from the AST
+-- zero STNil = ENil ext
+-- zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext)
+-- zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext)
+-- zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t)
+-- zero (STScal t) = case t of
+-- STI32 -> ENil ext
+-- STI64 -> ENil ext
+-- STF32 -> EConst ext STF32 0.0
+-- STF64 -> EConst ext STF64 0.0
+-- STBool -> ENil ext
+-- zero STAccum{} = error "Accumulators not allowed in input program"
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
-plus STNil _ _ = ENil ext
-plus (STPair t1 t2) a b =
- let t = STPair (d2 t1) (d2 t2)
- in plusSparse t a b $
- EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
- (EFst ext (EVar ext t IZ)))
- (plus t2 (ESnd ext (EVar ext t (IS IZ)))
- (ESnd ext (EVar ext t IZ)))
-plus (STEither t1 t2) a b =
- let t = STEither (d2 t1) (d2 t2)
- in plusSparse t a b $
- ECase ext (EVar ext t (IS IZ))
- (ECase ext (EVar ext t (IS IZ))
- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
- (EError t "plus l+r"))
- (ECase ext (EVar ext t (IS IZ))
- (EError t "plus r+l")
- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
-plus STArr{} _ _ = error "TODO plus on arrays"
- -- 'zero' creates an empty array; this should be a new primitive that
- -- (operationally) intelligently memcpy's the non-overlapping part and does
- -- a parallel add on the overlapping part.
-plus (STScal t) a b = case t of
- STI32 -> ENil ext
- STI64 -> ENil ext
- STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
- STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
- STBool -> ENil ext
-plus STAccum{} _ _ = error "Accumulators not allowed in input program"
+plus = EPlus
+-- plus STNil _ _ = ENil ext
+-- plus (STPair t1 t2) a b =
+-- let t = STPair (d2 t1) (d2 t2)
+-- in plusSparse t a b $
+-- EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
+-- (EFst ext (EVar ext t IZ)))
+-- (plus t2 (ESnd ext (EVar ext t (IS IZ)))
+-- (ESnd ext (EVar ext t IZ)))
+-- plus (STEither t1 t2) a b =
+-- let t = STEither (d2 t1) (d2 t2)
+-- in plusSparse t a b $
+-- ECase ext (EVar ext t (IS IZ))
+-- (ECase ext (EVar ext t (IS IZ))
+-- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
+-- (EError t "plus l+r"))
+-- (ECase ext (EVar ext t (IS IZ))
+-- (EError t "plus r+l")
+-- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
+-- plus STArr{} _ _ = error "TODO plus on arrays"
+-- plus (STScal t) a b = case t of
+-- STI32 -> ENil ext
+-- STI64 -> ENil ext
+-- STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
+-- STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
+-- STBool -> ENil ext
+-- plus STAccum{} _ _ = error "Accumulators not allowed in input program"
plusSparse :: STy a
-> Ex env (TEither TNil a) -> Ex env (TEither TNil a)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
new file mode 100644
index 0000000..0b32393
--- /dev/null
+++ b/src/CHAD/Types.hs
@@ -0,0 +1,65 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD.Types where
+
+import AST.Types
+
+
+type family D1 t where
+ D1 TNil = TNil
+ D1 (TPair a b) = TPair (D1 a) (D1 b)
+ D1 (TEither a b) = TEither (D1 a) (D1 b)
+ D1 (TMaybe a) = TMaybe (D1 a)
+ D1 (TArr n t) = TArr n (D1 t)
+ D1 (TScal t) = TScal t
+
+type family D2 t where
+ D2 TNil = TNil
+ D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b))
+ D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b))
+ D2 (TMaybe t) = TMaybe (D2 t)
+ D2 (TArr n t) = TArr n (D2 t)
+ D2 (TScal t) = D2s t
+
+type family D2s t where
+ D2s TI32 = TNil
+ D2s TI64 = TNil
+ D2s TF32 = TScal TF32
+ D2s TF64 = TScal TF64
+ D2s TBool = TNil
+
+type family D1E env where
+ D1E '[] = '[]
+ D1E (t : env) = D1 t : D1E env
+
+type family D2E env where
+ D2E '[] = '[]
+ D2E (t : env) = D2 t : D2E env
+
+type family D2AcE env where
+ D2AcE '[] = '[]
+ D2AcE (t : env) = TAccum (D2 t) : D2AcE env
+
+d1 :: STy t -> STy (D1 t)
+d1 STNil = STNil
+d1 (STPair a b) = STPair (d1 a) (d1 b)
+d1 (STEither a b) = STEither (d1 a) (d1 b)
+d1 (STMaybe t) = STMaybe (d1 t)
+d1 (STArr n t) = STArr n (d1 t)
+d1 (STScal t) = STScal t
+d1 STAccum{} = error "Accumulators not allowed in input program"
+
+d2 :: STy t -> STy (D2 t)
+d2 STNil = STNil
+d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b))
+d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b))
+d2 (STMaybe t) = STMaybe (d2 t)
+d2 (STArr n t) = STArr n (d2 t)
+d2 (STScal t) = case t of
+ STI32 -> STNil
+ STI64 -> STNil
+ STF32 -> STScal STF32
+ STF64 -> STScal STF64
+ STBool -> STNil
+d2 STAccum{} = error "Accumulators not allowed in input program"
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 7ffb14c..8728ec0 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -1,39 +1,44 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Interpreter (
interpret,
interpret',
Value,
- NoAccum(..),
- unAccum,
) where
+import Control.Monad (foldM)
import Data.Int (Int64)
import Data.Proxy
+import System.IO.Unsafe (unsafePerformIO)
+import Array
import AST
+import CHAD.Types
import Data
-import Array
-import Interpreter.Accum
import Interpreter.Rep
-import Control.Monad (foldM)
-interpret :: NoAccum t => Ex '[] t -> Rep t
-interpret e = runAcM (go e)
- where
- go :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t)
- go e' | Refl <- noAccum (Proxy @s) (Proxy @t) = interpret' SNil e'
+newtype AcM s a = AcM (IO a)
+ deriving newtype (Functor, Applicative, Monad)
-newtype Value s t = Value (Rep' s t)
+runAcM :: (forall s. AcM s a) -> a
+runAcM (AcM m) = unsafePerformIO m
-interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t)
+interpret :: Ex '[] t -> Rep t
+interpret e = runAcM (interpret' SNil e)
+
+newtype Value t = Value (Rep t)
+
+interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t)
interpret' env = \case
EVar _ _ i -> case slistIdx env i of Value x -> return x
ELet _ a b -> do
@@ -48,14 +53,17 @@ interpret' env = \case
ECase _ e a b -> interpret' env e >>= \case
Left x -> interpret' (Value x `SCons` env) a
Right y -> interpret' (Value y `SCons` env) b
+ ENothing _ _ -> _
+ EJust _ _ -> _
+ EMaybe _ _ _ _ -> _
EConstArr _ _ _ v -> return v
EBuild1 _ a b -> do
n <- fromIntegral @Int64 @Int <$> interpret' env a
arrayGenerateLinM (ShNil `ShCons` n)
(\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
EBuild _ dim a b -> do
- sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a
- arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b)
+ sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
+ arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
EFold1Inner _ a b -> do
let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
arr <- interpret' env b
@@ -75,9 +83,9 @@ interpret' env = \case
EConst _ _ v -> return v
EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
- EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b)
- EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e
- EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e
+ EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
+ EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
+ EOp _ op e -> interpretOp op <$> interpret' env e
EWith e1 e2 -> do
initval <- interpret' env e1
withAccum (typeOf e1) initval $ \accum ->
@@ -87,10 +95,16 @@ interpret' env = \case
val <- interpret' env e2
accum <- interpret' env e3
accumAdd accum i idx val
+ EZero t -> do
+ return $ makeZero t
+ EPlus t a b -> do
+ a' <- interpret' env a
+ b' <- interpret' env b
+ return $ makePlus t a' b'
EError _ s -> error $ "Interpreter: Program threw error: " ++ s
-interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t
-interpretOp _ op arg = case op of
+interpretOp :: SOp a t -> Rep a -> Rep t
+interpretOp op arg = case op of
OAdd st -> numericIsNum st $ uncurry (+) arg
OMul st -> numericIsNum st $ uncurry (*) arg
ONeg st -> numericIsNum st $ negate arg
@@ -100,23 +114,66 @@ interpretOp _ op arg = case op of
ONot -> not arg
OIf -> if arg then Left () else Right ()
+makeZero :: STy t -> Rep (D2 t)
+makeZero typ = case typ of
+ STNil -> ()
+ STPair _ _ -> Left ()
+ STEither _ _ -> Left ()
+ STMaybe _ -> Nothing
+ STArr n _ -> emptyArray n
+ STScal sty -> case sty of
+ STI32 -> ()
+ STI64 -> ()
+ STF32 -> 0.0
+ STF64 -> 0.0
+ STBool -> ()
+ STAccum{} -> error "Zero of Accum"
+
+makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t)
+makePlus typ a b = case typ of
+ STNil -> ()
+ STPair t1 t2 -> case (a, b) of
+ (Left (), _) -> b
+ (_, Left ()) -> a
+ (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2)
+ STEither t1 t2 -> case (a, b) of
+ (Left (), _) -> b
+ (_, Left ()) -> a
+ (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y))
+ (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y))
+ _ -> error "Plus of inconsistent Eithers"
+ STArr _ t ->
+ let sh1 = arrayShape a
+ sh2 = arrayShape b
+ in if | shapeSize sh1 == 0 -> b
+ | shapeSize sh2 == 0 -> a
+ | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i))
+ | otherwise -> error "Plus of inconsistently shaped arrays"
+ STScal sty -> case sty of
+ STI32 -> ()
+ STI64 -> ()
+ STF32 -> a + b
+ STF64 -> a + b
+ STBool -> ()
+ STAccum{} -> error "Plus of Accum"
+
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum STI32 = id
numericIsNum STI64 = id
numericIsNum STF32 = id
numericIsNum STF64 = id
-unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m))
- -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n
-unTupRepIdx _ nil _ SZ _ = nil
-unTupRepIdx p nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i
+unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m))
+ -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n
+unTupRepIdx nil _ SZ _ = nil
+unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i
-tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int))
- -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx))
-tupRepIdx _ _ SZ _ = ()
-tupRepIdx p uncons (SS n) tup =
+tupRepIdx :: (forall m. f (S m) -> (f m, Int))
+ -> SNat n -> f n -> Rep (Tup (Replicate n TIx))
+tupRepIdx _ SZ _ = ()
+tupRepIdx uncons (SS n) tup =
let (tup', i) = uncons tup
- in (tupRepIdx p uncons n tup', fromIntegral @Int @Int64 i)
+ in (tupRepIdx uncons n tup', fromIntegral @Int @Int64 i)
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons (IxCons idx i) = (idx, i)
@@ -124,33 +181,6 @@ ixUncons (IxCons idx i) = (idx, i)
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons (ShCons idx i) = (idx, i)
-class NoAccum t where
- noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t
-instance NoAccum TNil where
- noAccum _ _ = Refl
-instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where
- noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl
-instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where
- noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl
-instance NoAccum t => NoAccum (TArr n t) where
- noAccum p _ | Refl <- noAccum p (Proxy @t) = Refl
-instance NoAccum (TScal t) where
- noAccum _ _ = Refl
-
-unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t))
-unAccum _ STNil = Just Dict
-unAccum p (STPair t1 t2)
- | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict
- | otherwise = Nothing
-unAccum p (STEither t1 t2)
- | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict
- | otherwise = Nothing
-unAccum p (STArr _ t)
- | Just Dict <- unAccum p t = Just Dict
- | otherwise = Nothing
-unAccum _ STScal{} = Just Dict
-unAccum _ STAccum{} = Nothing
-
foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
foldl1M _ [] = error "foldl1M: empty list"
foldl1M f (tophead : toptail) = foldM f tophead toptail
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs
index b6a91df..af7be1e 100644
--- a/src/Interpreter/Accum.hs
+++ b/src/Interpreter/Accum.hs
@@ -51,6 +51,7 @@ type family Rep' s t where
Rep' s TNil = ()
Rep' s (TPair a b) = (Rep' s a, Rep' s b)
Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
+ Rep' s (TMaybe t) = Maybe (Rep' s t)
Rep' s (TArr n t) = Array n (Rep' s t)
Rep' s (TScal sty) = ScalRep sty
Rep' s (TAccum t) = Accum s t
@@ -61,16 +62,13 @@ data Accum s t = Accum (STy t) (ForeignPtr ())
tSize :: Proxy s -> STy t -> Rep' s t -> Int
tSize p ty x = tSize' p ty (Just x)
--- | Passing Nothing as the value means "this is (inside) an array element".
-tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int
-tSize' p typ val = case typ of
+tSize' :: Proxy s -> STy t -> Int
+tSize' p typ = case typ of
STNil -> 0
- STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val)
- STEither a b ->
- case val of
- Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing)
- Just (Left x) -> 1 + tSize' p a (Just x) -- '1 +' is for runtime sanity checking
- Just (Right y) -> 1 + tSize' p b (Just y) -- idem
+ STPair a b -> tSize' p a + tSize' p b
+ STEither a b -> 1 + max (tSize' p a) (tSize' p b)
+ -- Representation of Maybe t is the same as Either () t; the add operation is different, however.
+ STMaybe t -> tSize' p (STEither STNil t)
STArr ndim t ->
case val of
Nothing -> error "Nested arrays not supported in this implementation"
@@ -99,7 +97,7 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
go inarr b (snd val) off1
STEither a b -> do
let !(I# off#) = off
- case val of
+ off1 <- case val of
Left x -> do
let !(I8# tag#) = 0
writeInt8# addr# off# tag#
@@ -108,6 +106,11 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
let !(I8# tag#) = 1
writeInt8# addr# off# tag#
go inarr b y (off + 1)
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing))
+ else return off1
+ -- Representation is the same, but add operation is different
+ STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off
STArr _ t
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
@@ -158,6 +161,8 @@ accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
if inarr
then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
else return (off1, val)
+ -- Representation is the same, but add operation is different
+ STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off
STArr ndim t
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
@@ -219,6 +224,7 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr
(STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
(STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
(STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
+ (STMaybe t, _, _) -> _ idx val
(STArr rank eltty, _, _)
| inarr -> error "Nested arrays"
| otherwise -> do
diff --git a/src/Interpreter/AccumOld.hs b/src/Interpreter/AccumOld.hs
new file mode 100644
index 0000000..af7be1e
--- /dev/null
+++ b/src/Interpreter/AccumOld.hs
@@ -0,0 +1,366 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UnboxedTuples #-}
+module Interpreter.Accum (
+ AcM,
+ runAcM,
+ Rep',
+ Accum,
+ withAccum,
+ accumAdd,
+ inParallel,
+) where
+
+import Control.Concurrent
+import Control.Monad (when, forM_)
+import Data.Bifunctor (second)
+import Data.Proxy
+import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
+import Foreign.Storable (sizeOf)
+import GHC.Exts
+import GHC.Float
+import GHC.Int
+import GHC.IO (IO(..))
+import GHC.Word
+import System.IO.Unsafe (unsafePerformIO)
+
+import Array
+import AST
+import Data
+
+
+newtype AcM s a = AcM (IO a)
+ deriving newtype (Functor, Applicative, Monad)
+
+runAcM :: (forall s. AcM s a) -> a
+runAcM (AcM m) = unsafePerformIO m
+
+-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined.
+type family Rep' s t where
+ Rep' s TNil = ()
+ Rep' s (TPair a b) = (Rep' s a, Rep' s b)
+ Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
+ Rep' s (TMaybe t) = Maybe (Rep' s t)
+ Rep' s (TArr n t) = Array n (Rep' s t)
+ Rep' s (TScal sty) = ScalRep sty
+ Rep' s (TAccum t) = Accum s t
+
+-- | Floats and integers are accumulated; booleans are left as-is.
+data Accum s t = Accum (STy t) (ForeignPtr ())
+
+tSize :: Proxy s -> STy t -> Rep' s t -> Int
+tSize p ty x = tSize' p ty (Just x)
+
+tSize' :: Proxy s -> STy t -> Int
+tSize' p typ = case typ of
+ STNil -> 0
+ STPair a b -> tSize' p a + tSize' p b
+ STEither a b -> 1 + max (tSize' p a) (tSize' p b)
+ -- Representation of Maybe t is the same as Either () t; the add operation is different, however.
+ STMaybe t -> tSize' p (STEither STNil t)
+ STArr ndim t ->
+ case val of
+ Nothing -> error "Nested arrays not supported in this implementation"
+ Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing
+ STScal sty -> goScal sty
+ STAccum{} -> error "Nested accumulators unsupported"
+ where
+ goScal :: SScalTy t -> Int
+ goScal STI32 = 4
+ goScal STI64 = 8
+ goScal STF32 = 4
+ goScal STF64 = 8
+ goScal STBool = 1
+
+-- | This operation does not commute with 'accumAdd', so it must be used with
+-- care. Furthermore it must be used on exactly the same value as tSize was
+-- called on. Hence it lives in IO, not in AcM.
+accumWrite :: forall s t. Accum s t -> Rep' s t -> IO ()
+accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
+ go inarr ty val off = case ty of
+ STNil -> return off
+ STPair a b -> do
+ off1 <- go inarr a (fst val) off
+ go inarr b (snd val) off1
+ STEither a b -> do
+ let !(I# off#) = off
+ off1 <- case val of
+ Left x -> do
+ let !(I8# tag#) = 0
+ writeInt8# addr# off# tag#
+ go inarr a x (off + 1)
+ Right y -> do
+ let !(I8# tag#) = 1
+ writeInt8# addr# off# tag#
+ go inarr b y (off + 1)
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing))
+ else return off1
+ -- Representation is the same, but add operation is different
+ STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off
+ STArr _ t
+ | inarr -> error "Nested arrays not supported in this implementation"
+ | otherwise -> do
+ off1 <- goShape (arrayShape val) off
+ let eltsize = tSize' (Proxy @s) t Nothing
+ n = arraySize val
+ traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
+ return (off1 + eltsize * n)
+ STScal sty -> goScal sty val off
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ goShape :: Shape n -> Int -> IO Int
+ goShape ShNil off = return off
+ goShape (ShCons sh n) off = do
+ off1@(I# off1#) <- goShape sh off
+ let !(I64# n'#) = fromIntegral n
+ writeInt64# addr# off1# n'#
+ return (off1 + 8)
+
+ goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
+ goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x
+ goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x
+ goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x
+ goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x
+ goScal STBool b off@(I# off#) = do
+ let !(I8# i) = fromIntegral (fromEnum b)
+ off + 1 <$ writeInt8# addr# off# i
+
+ in () <$ go False topty top_value 0
+
+accumRead :: forall s t. Accum s t -> AcM s (Rep' s t)
+accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t')
+ go inarr ty off = case ty of
+ STNil -> return (off, ())
+ STPair a b -> do
+ (off1, x) <- go inarr a off
+ (off2, y) <- go inarr b off1
+ return (off1 + off2, (x, y))
+ STEither a b -> do
+ let !(I# off#) = off
+ tag <- readInt8 addr# off#
+ (off1, val) <- case tag of
+ 0 -> fmap Left <$> go inarr a (off + 1)
+ 1 -> fmap Right <$> go inarr b (off + 1)
+ _ -> error "Invalid tag in accum memory"
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
+ else return (off1, val)
+ -- Representation is the same, but add operation is different
+ STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off
+ STArr ndim t
+ | inarr -> error "Nested arrays not supported in this implementation"
+ | otherwise -> do
+ (off1, sh) <- readShape addr# ndim off
+ let eltsize = tSize' (Proxy @s) t Nothing
+ n = shapeSize sh
+ arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
+ return (off1 + eltsize * n, arr)
+ STScal sty -> goScal sty off
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t')
+ goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off#
+ goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off#
+ goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off#
+ goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off#
+ goScal STBool off@(I# off#) = do
+ i8 <- readInt8 addr# off#
+ return (off + 1, toEnum (fromIntegral i8))
+
+ in snd <$> go False topty 0
+
+readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n)
+readShape _ SZ off = return (off, ShNil)
+readShape mbarr (SS ndim) off = do
+ (off1@(I# off1#), sh) <- readShape mbarr ndim off
+ n' <- readInt64 mbarr off1#
+ return (off1 + 8, ShCons sh (fromIntegral n'))
+
+-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of
+-- the list.
+data InvShape n where
+ IShNil :: InvShape Z
+ IShCons :: Int -- ^ How many subarrays are there?
+ -> Int -- ^ What is the size of all subarrays together?
+ -> InvShape n -- ^ Sub array inverted shape
+ -> InvShape (S n)
+
+ishSize :: InvShape n -> Int
+ishSize IShNil = 1
+ishSize (IShCons _ sz _) = sz
+
+invertShape :: forall n. Shape n -> InvShape n
+invertShape | Refl <- lemPlusZero @n = flip go IShNil
+ where
+ go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m)
+ go ShNil ish = ish
+ go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)
+
+accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s ()
+accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
+ let
+ go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO ()
+ go inarr ty SZ () val off = () <$ performAdd inarr ty val off
+ go inarr ty (SS dep) idx val off = case (ty, idx, val) of
+ (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
+ (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
+ (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd"
+ (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
+ (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off
+ (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd"
+ (STMaybe t, _, _) -> _ idx val
+ (STArr rank eltty, _, _)
+ | inarr -> error "Nested arrays"
+ | otherwise -> do
+ (off1, ish) <- second invertShape <$> readShape addr# rank off
+ goArr (SS dep) ish eltty idx val off1
+ (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth"
+ (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth"
+ (STAccum{}, _, _) -> error "Nested accumulators unsupported"
+
+ goArr :: SNat i' -> InvShape n -> STy t'
+ -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO ()
+ goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off
+ goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
+ goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
+ let i' = fromIntegral @(Rep' s TIx) @Int i
+ when (i' < 0 || i' >= n) $
+ error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
+ goArr depm1 ish t1 idx val (off + i' * ishSize ish)
+
+ performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int
+ performAddArr arraySz eltty val off = do
+ let eltsize = tSize' (Proxy @s) eltty Nothing
+ forM_ [0 .. arraySz - 1] $ \lini ->
+ performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)
+ return (off + arraySz * eltsize)
+
+ performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
+ performAdd inarr ty val off = case ty of
+ STNil -> return off
+ STPair t1 t2 -> do
+ off1 <- performAdd inarr t1 (fst val) off
+ performAdd inarr t2 (snd val) off1
+ STEither t1 t2 -> do
+ let !(I# off#) = off
+ tag <- readInt8 addr# off#
+ off1 <- case (val, tag) of
+ (Left val1, 0) -> performAdd inarr t1 val1 (off + 1)
+ (Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
+ _ -> error "accumAdd: Tag mismatch for Either"
+ if inarr
+ then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing))
+ else return off1
+ STArr n ty'
+ | inarr -> error "Nested array"
+ | otherwise -> do
+ (off1, sh) <- readShape addr# n off
+ performAddArr (shapeSize sh) ty' val off1
+ STScal ty' -> performAddScal ty' val off
+ STAccum{} -> error "Nested accumulators unsupported"
+
+ performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int
+ performAddScal STI32 (I32# x#) off@(I# off#)
+ | sizeOf (undefined :: Int) == 4
+ = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#))
+ | otherwise
+ = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#))
+ performAddScal STI64 (I64# x#) off@(I# off#)
+ | sizeOf (undefined :: Int) == 8
+ = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#))
+ | otherwise
+ = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#))
+ performAddScal STF32 x off@(I# off#) =
+ off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w))
+ performAddScal STF64 x off@(I# off#) =
+ off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w))
+ performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans
+
+ casLoop :: Eq w
+ => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#)
+ -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed)
+ -> Addr# -- ^ Address to attempt to modify
+ -> (w -> w) -- ^ Operation to apply to the value
+ -> IO ()
+ casLoop readOp casOp addr modify = readOp addr 0# >>= loop
+ where
+ loop value = do
+ value' <- casOp addr value (modify value)
+ if value == value'
+ then return ()
+ else loop value'
+
+ in () <$ go False topty top_depth top_index top_value 0
+
+withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t)
+withAccum ty start fun = do
+ -- The initial write must happen before any of the adds or reads, so it makes
+ -- sense to put it in IO together with the allocation, instead of in AcM.
+ accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start)
+ ptr <- newForeignPtr finalizerFree buffer
+ let accum = Accum ty ptr
+ accumWrite accum start
+ return accum
+ b <- fun accum
+ out <- accumRead accum
+ return (b, out)
+
+inParallel :: [AcM s t] -> AcM s [t]
+inParallel actions = AcM $ do
+ mvars <- mapM (\_ -> newEmptyMVar) actions
+ forM_ (zip actions mvars) $ \(AcM action, var) ->
+ forkIO $ action >>= putMVar var
+ mapM takeMVar mvars
+
+-- | Offset is in bytes.
+readInt8 :: Addr# -> Int# -> IO Int8
+readInt32 :: Addr# -> Int# -> IO Int32
+readInt64 :: Addr# -> Int# -> IO Int64
+readWord32 :: Addr# -> Int# -> IO Word32
+readWord64 :: Addr# -> Int# -> IO Word64
+readFloat :: Addr# -> Int# -> IO Float
+readDouble :: Addr# -> Int# -> IO Double
+readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #)
+readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #)
+readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #)
+readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #)
+readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #)
+readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #)
+readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #)
+
+writeInt8# :: Addr# -> Int# -> Int8# -> IO ()
+writeInt32# :: Addr# -> Int# -> Int32# -> IO ()
+writeInt64# :: Addr# -> Int# -> Int64# -> IO ()
+writeFloat# :: Addr# -> Int# -> Float# -> IO ()
+writeDouble# :: Addr# -> Int# -> Double# -> IO ()
+writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #)
+
+fetchAddWord# :: Addr# -> Int# -> Word# -> IO ()
+fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #)
+
+atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32
+atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64
+atomicCasWord32Addr addr (W32# expected) (W32# desired) =
+ IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #)
+atomicCasWord64Addr addr (W64# expected) (W64# desired) =
+ IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #)
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index 1ded773..7add442 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -2,6 +2,8 @@
{-# LANGUAGE TypeFamilies #-}
module Interpreter.Rep where
+import Data.IORef
+import qualified Data.Vector.Mutable as MV
import GHC.TypeError
import Array
@@ -12,6 +14,22 @@ type family Rep t where
Rep TNil = ()
Rep (TPair a b) = (Rep a, Rep b)
Rep (TEither a b) = Either (Rep a) (Rep b)
+ Rep (TMaybe t) = Maybe (Rep t)
Rep (TArr n t) = Array n (Rep t)
Rep (TScal sty) = ScalRep sty
- Rep (TAccum t) = TypeError (Text "Accumulator in Rep")
+ Rep (TAccum t) = IORef (RepAc t)
+
+type family RepAc t where
+ RepAc TNil = ()
+ RepAc (TPair a b) = (RepAc a, RepAc b)
+ -- This is annoying when working with values of type 'RepAc t', because
+ -- failing a pattern match does not generate negative type information.
+ -- However, it works, saves us from having to defining a LEither type
+ -- first-class in the type system with
+ -- Rep (LEither a b) = Maybe (Either a b)
+ -- and it's not even incorrect, in a way.
+ RepAc (TMaybe (TEither a b)) = IORef (Maybe (Either (RepAc a) (RepAc b)))
+ RepAc (TMaybe t) = IORef (Maybe (RepAc t))
+ RepAc (TArr n t) = (Shape n, MV.IOVector (RepAc t))
+ RepAc (TScal sty) = IORef (ScalRep sty)
+ RepAc (TAccum t) = TypeError (Text "Nested accumulators")
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 1640729..3ac68ed 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -68,6 +68,8 @@ simplify' = \case
-- TODO: constant folding for operations
+ -- TODO: accum of zero, plus of zero
+
EVar _ t i -> EVar ext t i
ELet _ a b -> ELet ext (simplify' a) (simplify' b)
EPair _ a b -> EPair ext (simplify' a) (simplify' b)
@@ -92,6 +94,8 @@ simplify' = \case
EOp _ op e -> EOp ext op (simplify' e)
EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2)
EAccum i e1 e2 e3 -> EAccum i (simplify' e1) (simplify' e2) (simplify' e3)
+ EZero t -> EZero t
+ EPlus t a b -> EPlus t (simplify' a) (simplify' b)
EError t s -> EError t s
cheapExpr :: Expr x env t -> Bool
@@ -129,6 +133,8 @@ hasAdds = \case
EOp _ _ e -> hasAdds e
EWith a b -> hasAdds a || hasAdds b
EAccum _ _ _ _ -> True
+ EZero _ -> False
+ EPlus _ a b -> hasAdds a || hasAdds b
EError _ _ -> False
checkAccumInScope :: SList STy env -> Bool