diff options
-rw-r--r-- | chad-fast.cabal | 4 | ||||
-rw-r--r-- | src/AST.hs | 109 | ||||
-rw-r--r-- | src/AST/Count.hs | 3 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 7 | ||||
-rw-r--r-- | src/AST/Types.hs | 95 | ||||
-rw-r--r-- | src/Array.hs | 15 | ||||
-rw-r--r-- | src/CHAD.hs | 136 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 65 | ||||
-rw-r--r-- | src/Interpreter.hs | 142 | ||||
-rw-r--r-- | src/Interpreter/Accum.hs | 26 | ||||
-rw-r--r-- | src/Interpreter/AccumOld.hs | 366 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 20 | ||||
-rw-r--r-- | src/Simplify.hs | 6 |
13 files changed, 756 insertions, 238 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 6314acd..1b95c66 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -15,14 +15,16 @@ library AST.Count AST.Env AST.Pretty + AST.Types AST.Weaken AST.Weaken.Auto CHAD + CHAD.Types -- Compile Data Example Interpreter - Interpreter.Accum + -- Interpreter.AccumOld Interpreter.Rep Language Language.AST @@ -13,92 +13,19 @@ {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module AST (module AST, module AST.Weaken) where +module AST (module AST, module AST.Types, module AST.Weaken) where import Data.Functor.Const import Data.Kind (Type) -import Data.Int -import Data.Type.Equality import Array import AST.Env +import AST.Types import AST.Weaken +import CHAD.Types import Data -data Ty - = TNil - | TPair Ty Ty - | TEither Ty Ty - | TArr Nat Ty -- ^ rank, element type - | TScal ScalTy - | TAccum Ty - deriving (Show, Eq, Ord) - -data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool - deriving (Show, Eq, Ord) - -type STy :: Ty -> Type -data STy t where - STNil :: STy TNil - STPair :: STy a -> STy b -> STy (TPair a b) - STEither :: STy a -> STy b -> STy (TEither a b) - STArr :: SNat n -> STy t -> STy (TArr n t) - STScal :: SScalTy t -> STy (TScal t) - STAccum :: STy t -> STy (TAccum t) -deriving instance Show (STy t) - -instance TestEquality STy where - testEquality STNil STNil = Just Refl - testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl - testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl - testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl - testEquality _ _ = Nothing - -data SScalTy t where - STI32 :: SScalTy TI32 - STI64 :: SScalTy TI64 - STF32 :: SScalTy TF32 - STF64 :: SScalTy TF64 - STBool :: SScalTy TBool -deriving instance Show (SScalTy t) - -instance TestEquality SScalTy where - testEquality STI32 STI32 = Just Refl - testEquality STI64 STI64 = Just Refl - testEquality STF32 STF32 = Just Refl - testEquality STF64 STF64 = Just Refl - testEquality STBool STBool = Just Refl - testEquality _ _ = Nothing - -scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) -scalRepIsShow STI32 = Dict -scalRepIsShow STI64 = Dict -scalRepIsShow STF32 = Dict -scalRepIsShow STF64 = Dict -scalRepIsShow STBool = Dict - -type TIx = TScal TI64 - -tIx :: STy TIx -tIx = STScal STI64 - -type family ScalRep t where - ScalRep TI32 = Int32 - ScalRep TI64 = Int64 - ScalRep TF32 = Float - ScalRep TF64 = Double - ScalRep TBool = Bool - -type family ScalIsNumeric t where - ScalIsNumeric TI32 = True - ScalIsNumeric TI64 = True - ScalIsNumeric TF32 = True - ScalIsNumeric TF64 = True - ScalIsNumeric TBool = False - -- | This index is flipped around from the usual direction: the smallest index -- is at the _heart_ of the nesting, not at the outside. The outermost layer -- indexes into the _outer_ dimension of the type @t@. This makes indices into @@ -107,6 +34,7 @@ type family AcIdx t i where AcIdx t Z = TNil AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i) AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i) + AcIdx (TMaybe t) (S i) = AcIdx t i AcIdx (TArr Z t) (S i) = AcIdx t i AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i) @@ -114,12 +42,20 @@ type family AcVal t i where AcVal t Z = t AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i) AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i) + AcVal (TMaybe t) (S i) = AcVal t i AcVal (TArr Z t) (S i) = AcVal t i AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i -- General assumption: head of the list (whatever way it is associated) is the -- inner variable / inner array dimension. In pretty printing, the inner -- variable / inner dimension is printed on the _right_. +-- +-- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the +-- type transformation of CHAD. Indeed, these constructors are created _by_ +-- CHAD, and are intended to be eliminated after simplification, so that the +-- input program as well as the output program do not contain these +-- constructors. +-- TODO: ensure this by a "stage" type parameter. type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type data Expr x env t where -- lambda calculus @@ -134,6 +70,9 @@ data Expr x env t where EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b) EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b) ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c + ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) + EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) + EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) @@ -157,6 +96,10 @@ data Expr x env t where EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil + -- monoidal operations (to be desugared to regular operations after simplification) + EZero :: STy t -> Expr x env (D2 t) + EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) + -- partiality EError :: STy a -> String -> Expr x env a deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) @@ -220,6 +163,9 @@ typeOf = \case EInl _ t2 e -> STEither (typeOf e) t2 EInr _ t1 e -> STEither t1 (typeOf e) ECase _ _ a _ -> typeOf a + ENothing _ t -> STMaybe t + EJust _ e -> STMaybe (typeOf e) + EMaybe _ e _ _ -> typeOf e EConstArr _ n t _ -> STArr n (STScal t) EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) @@ -239,6 +185,9 @@ typeOf = \case EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) EAccum _ _ _ _ -> STNil + EZero t -> d2 t + EPlus t _ _ -> d2 t + EError t _ -> t unSNat :: SNat n -> Nat @@ -250,6 +199,7 @@ unSTy = \case STNil -> TNil STPair a b -> TPair (unSTy a) (unSTy b) STEither a b -> TEither (unSTy a) (unSTy b) + STMaybe t -> TMaybe (unSTy t) STArr n t -> TArr (unSNat n) (unSTy t) STScal t -> TScal (unSScalTy t) STAccum t -> TAccum (unSTy t) @@ -288,6 +238,9 @@ subst' f w = \case EInl x t e -> EInl x t (subst' f w e) EInr x t e -> EInr x t (subst' f w e) ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) + ENothing x t -> ENothing x t + EJust x e -> EJust x (subst' f w e) + EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) EConstArr x n t a -> EConstArr x n t a EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) @@ -303,6 +256,8 @@ subst' f w = \case EOp x op e -> EOp x op (subst' f w e) EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3) + EZero t -> EZero t + EPlus t a b -> EPlus t (subst' f w a) (subst' f w b) EError t s -> EError t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -339,6 +294,7 @@ class KnownTy t where knownTy :: STy t instance KnownTy TNil where knownTy = STNil instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy @@ -351,6 +307,7 @@ styKnown :: STy t -> Dict (KnownTy t) styKnown STNil = Dict styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STMaybe t) | Dict <- styKnown t = Dict styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict styKnown (STScal t) | Dict <- sscaltyKnown t = Dict styKnown (STAccum t) | Dict <- styKnown t = Dict diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 6a00e83..40a46f6 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -46,6 +46,7 @@ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) -- | This code is executed many times scaleMany :: Occ -> Occ +scaleMany (Occ l Zero) = Occ l Zero scaleMany (Occ l _) = Occ l Many occCount :: Idx env a -> Expr x env t -> Occ @@ -124,6 +125,8 @@ occCountGeneral onehot unpush alter many = go WId EOp _ _ e -> re e EWith a b -> re a <> re1 b EAccum _ a b e -> re a <> re b <> re e + EZero _ -> mempty + EPlus _ a b -> re a <> re b EError{} -> mempty where re :: Monoid (r env') => Expr x env' t'' -> r env' diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 2ce883b..f5e681a 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -181,6 +181,13 @@ ppExpr' d val = \case return $ showParen (d > 10) $ showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' + EZero _ -> return $ showString "zero" + + EPlus _ a b -> do + a' <- ppExpr' 11 val a + b' <- ppExpr' 11 val b + return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b' + EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS diff --git a/src/AST/Types.hs b/src/AST/Types.hs new file mode 100644 index 0000000..a3e5080 --- /dev/null +++ b/src/AST/Types.hs @@ -0,0 +1,95 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +module AST.Types where + +import Data.Int (Int32, Int64) +import Data.Kind (Type) +import Data.Type.Equality + +import Data + + +data Ty + = TNil + | TPair Ty Ty + | TEither Ty Ty + | TMaybe Ty + | TArr Nat Ty -- ^ rank, element type + | TScal ScalTy + | TAccum Ty + deriving (Show, Eq, Ord) + +data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool + deriving (Show, Eq, Ord) + +type STy :: Ty -> Type +data STy t where + STNil :: STy TNil + STPair :: STy a -> STy b -> STy (TPair a b) + STEither :: STy a -> STy b -> STy (TEither a b) + STMaybe :: STy a -> STy (TMaybe a) + STArr :: SNat n -> STy t -> STy (TArr n t) + STScal :: SScalTy t -> STy (TScal t) + STAccum :: STy t -> STy (TAccum t) +deriving instance Show (STy t) + +instance TestEquality STy where + testEquality STNil STNil = Just Refl + testEquality STNil _ = Nothing + testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl + testEquality STPair{} _ = Nothing + testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl + testEquality STEither{} _ = Nothing + testEquality (STMaybe a) (STMaybe a') | Just Refl <- testEquality a a' = Just Refl + testEquality STMaybe{} _ = Nothing + testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl + testEquality STArr{} _ = Nothing + testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl + testEquality STScal{} _ = Nothing + testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl + testEquality STAccum{} _ = Nothing + +data SScalTy t where + STI32 :: SScalTy TI32 + STI64 :: SScalTy TI64 + STF32 :: SScalTy TF32 + STF64 :: SScalTy TF64 + STBool :: SScalTy TBool +deriving instance Show (SScalTy t) + +instance TestEquality SScalTy where + testEquality STI32 STI32 = Just Refl + testEquality STI64 STI64 = Just Refl + testEquality STF32 STF32 = Just Refl + testEquality STF64 STF64 = Just Refl + testEquality STBool STBool = Just Refl + testEquality _ _ = Nothing + +scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) +scalRepIsShow STI32 = Dict +scalRepIsShow STI64 = Dict +scalRepIsShow STF32 = Dict +scalRepIsShow STF64 = Dict +scalRepIsShow STBool = Dict + +type TIx = TScal TI64 + +tIx :: STy TIx +tIx = STScal STI64 + +type family ScalRep t where + ScalRep TI32 = Int32 + ScalRep TI64 = Int64 + ScalRep TF32 = Float + ScalRep TF64 = Double + ScalRep TBool = Bool + +type family ScalIsNumeric t where + ScalIsNumeric TI32 = True + ScalIsNumeric TI64 = True + ScalIsNumeric TF32 = True + ScalIsNumeric TF64 = True + ScalIsNumeric TBool = False diff --git a/src/Array.hs b/src/Array.hs index 9a770c4..0d585a9 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -17,11 +17,13 @@ data Shape n where ShNil :: Shape Z ShCons :: Shape n -> Int -> Shape (S n) deriving instance Show (Shape n) +deriving instance Eq (Shape n) data Index n where IxNil :: Index Z IxCons :: Index n -> Int -> Index (S n) deriving instance Show (Index n) +deriving instance Eq (Index n) shapeSize :: Shape n -> Int shapeSize ShNil = 0 @@ -38,6 +40,10 @@ toLinearIndex :: Shape n -> Index n -> Int toLinearIndex ShNil IxNil = 0 toLinearIndex (sh `ShCons` n) (idx `IxCons` i) = toLinearIndex sh idx * n + i +emptyShape :: SNat n -> Shape n +emptyShape SZ = ShNil +emptyShape (SS m) = emptyShape m `ShCons` 0 + -- | TODO: this Vector is a boxed vector, which is horrendously inefficient. data Array (n :: Nat) t = Array (Shape n) (Vector t) @@ -49,6 +55,9 @@ arrayShape (Array sh _) = sh arraySize :: Array n t -> Int arraySize (Array sh _) = shapeSize sh +emptyArray :: SNat n -> Array n t +emptyArray n = Array (emptyShape n) V.empty + arrayIndex :: Array n t -> Index n -> t arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx) @@ -58,6 +67,12 @@ arrayIndexLinear (Array _ v) i = v V.! i arrayIndex1 :: Array (S n) t -> Int -> Array n t arrayIndex1 (Array (sh `ShCons` _) v) i = let sz = shapeSize sh in Array sh (V.slice (sz * i) sz v) +arrayGenerate :: Shape n -> (Index n -> t) -> Array n t +arrayGenerate sh f = arrayGenerateLin sh (f . fromLinearIndex sh) + +arrayGenerateLin :: Shape n -> (Int -> t) -> Array n t +arrayGenerateLin sh f = Array sh (V.generate (shapeSize sh) f) + arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t) arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh) diff --git a/src/CHAD.hs b/src/CHAD.hs index 7747d46..1ab2da0 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -37,6 +37,7 @@ import AST import AST.Count import AST.Env import AST.Weaken.Auto +import CHAD.Types import Data import Lemmas @@ -288,66 +289,12 @@ vectorise1Binds env n (bs `BPush` (t, e)) = (vectoriseExpr SNil (bindingsBinds bs) env e) in bs' `BPush` (STArr (SS SZ) t, e') -type family D1 t where - D1 TNil = TNil - D1 (TPair a b) = TPair (D1 a) (D1 b) - D1 (TEither a b) = TEither (D1 a) (D1 b) - D1 (TArr n t) = TArr n (D1 t) - D1 (TScal t) = TScal t - -type family D2 t where - D2 TNil = TNil - D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b)) - D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b)) - D2 (TArr n t) = TArr n (D2 t) - D2 (TScal t) = D2s t - -type family D2s t where - D2s TI32 = TNil - D2s TI64 = TNil - D2s TF32 = TScal TF32 - D2s TF64 = TScal TF64 - D2s TBool = TNil - -type family D1E env where - D1E '[] = '[] - D1E (t : env) = D1 t : D1E env - -type family D2E env where - D2E '[] = '[] - D2E (t : env) = D2 t : D2E env - -type family D2AcE env where - D2AcE '[] = '[] - D2AcE (t : env) = TAccum (D2 t) : D2AcE env - -- | Select only the types from the environment that have the specified storage type family Select env sto s where Select '[] '[] _ = '[] Select (t : ts) (s : sto) s = t : Select ts sto s Select (_ : ts) (_ : sto) s = Select ts sto s -d1 :: STy t -> STy (D1 t) -d1 STNil = STNil -d1 (STPair a b) = STPair (d1 a) (d1 b) -d1 (STEither a b) = STEither (d1 a) (d1 b) -d1 (STArr n t) = STArr n (d1 t) -d1 (STScal t) = STScal t -d1 STAccum{} = error "Accumulators not allowed in input program" - -d2 :: STy t -> STy (D2 t) -d2 STNil = STNil -d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b)) -d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b)) -d2 (STArr n t) = STArr n (d2 t) -d2 (STScal t) = case t of - STI32 -> STNil - STI64 -> STNil - STF32 -> STScal STF32 - STF64 -> STScal STF64 - STBool -> STNil -d2 STAccum{} = error "Accumulators not allowed in input program" - conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) @@ -362,48 +309,49 @@ conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i) conv2Idx DTop i = case i of {} zero :: STy t -> Ex env (D2 t) -zero STNil = ENil ext -zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext) -zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) -zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) -zero (STScal t) = case t of - STI32 -> ENil ext - STI64 -> ENil ext - STF32 -> EConst ext STF32 0.0 - STF64 -> EConst ext STF64 0.0 - STBool -> ENil ext -zero STAccum{} = error "Accumulators not allowed in input program" +zero = EZero +-- TODO: this original definition needs to be used as the post-processing after +-- simplification, to eliminate the monoid operations from the AST +-- zero STNil = ENil ext +-- zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext) +-- zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) +-- zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) +-- zero (STScal t) = case t of +-- STI32 -> ENil ext +-- STI64 -> ENil ext +-- STF32 -> EConst ext STF32 0.0 +-- STF64 -> EConst ext STF64 0.0 +-- STBool -> ENil ext +-- zero STAccum{} = error "Accumulators not allowed in input program" plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) -plus STNil _ _ = ENil ext -plus (STPair t1 t2) a b = - let t = STPair (d2 t1) (d2 t2) - in plusSparse t a b $ - EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) - (EFst ext (EVar ext t IZ))) - (plus t2 (ESnd ext (EVar ext t (IS IZ))) - (ESnd ext (EVar ext t IZ))) -plus (STEither t1 t2) a b = - let t = STEither (d2 t1) (d2 t2) - in plusSparse t a b $ - ECase ext (EVar ext t (IS IZ)) - (ECase ext (EVar ext t (IS IZ)) - (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) - (EError t "plus l+r")) - (ECase ext (EVar ext t (IS IZ)) - (EError t "plus r+l") - (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) -plus STArr{} _ _ = error "TODO plus on arrays" - -- 'zero' creates an empty array; this should be a new primitive that - -- (operationally) intelligently memcpy's the non-overlapping part and does - -- a parallel add on the overlapping part. -plus (STScal t) a b = case t of - STI32 -> ENil ext - STI64 -> ENil ext - STF32 -> EOp ext (OAdd STF32) (EPair ext a b) - STF64 -> EOp ext (OAdd STF64) (EPair ext a b) - STBool -> ENil ext -plus STAccum{} _ _ = error "Accumulators not allowed in input program" +plus = EPlus +-- plus STNil _ _ = ENil ext +-- plus (STPair t1 t2) a b = +-- let t = STPair (d2 t1) (d2 t2) +-- in plusSparse t a b $ +-- EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) +-- (EFst ext (EVar ext t IZ))) +-- (plus t2 (ESnd ext (EVar ext t (IS IZ))) +-- (ESnd ext (EVar ext t IZ))) +-- plus (STEither t1 t2) a b = +-- let t = STEither (d2 t1) (d2 t2) +-- in plusSparse t a b $ +-- ECase ext (EVar ext t (IS IZ)) +-- (ECase ext (EVar ext t (IS IZ)) +-- (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) +-- (EError t "plus l+r")) +-- (ECase ext (EVar ext t (IS IZ)) +-- (EError t "plus r+l") +-- (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) +-- plus STArr{} _ _ = error "TODO plus on arrays" +-- plus (STScal t) a b = case t of +-- STI32 -> ENil ext +-- STI64 -> ENil ext +-- STF32 -> EOp ext (OAdd STF32) (EPair ext a b) +-- STF64 -> EOp ext (OAdd STF64) (EPair ext a b) +-- STBool -> ENil ext +-- plus STAccum{} _ _ = error "Accumulators not allowed in input program" plusSparse :: STy a -> Ex env (TEither TNil a) -> Ex env (TEither TNil a) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs new file mode 100644 index 0000000..0b32393 --- /dev/null +++ b/src/CHAD/Types.hs @@ -0,0 +1,65 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Types where + +import AST.Types + + +type family D1 t where + D1 TNil = TNil + D1 (TPair a b) = TPair (D1 a) (D1 b) + D1 (TEither a b) = TEither (D1 a) (D1 b) + D1 (TMaybe a) = TMaybe (D1 a) + D1 (TArr n t) = TArr n (D1 t) + D1 (TScal t) = TScal t + +type family D2 t where + D2 TNil = TNil + D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b)) + D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b)) + D2 (TMaybe t) = TMaybe (D2 t) + D2 (TArr n t) = TArr n (D2 t) + D2 (TScal t) = D2s t + +type family D2s t where + D2s TI32 = TNil + D2s TI64 = TNil + D2s TF32 = TScal TF32 + D2s TF64 = TScal TF64 + D2s TBool = TNil + +type family D1E env where + D1E '[] = '[] + D1E (t : env) = D1 t : D1E env + +type family D2E env where + D2E '[] = '[] + D2E (t : env) = D2 t : D2E env + +type family D2AcE env where + D2AcE '[] = '[] + D2AcE (t : env) = TAccum (D2 t) : D2AcE env + +d1 :: STy t -> STy (D1 t) +d1 STNil = STNil +d1 (STPair a b) = STPair (d1 a) (d1 b) +d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STMaybe t) = STMaybe (d1 t) +d1 (STArr n t) = STArr n (d1 t) +d1 (STScal t) = STScal t +d1 STAccum{} = error "Accumulators not allowed in input program" + +d2 :: STy t -> STy (D2 t) +d2 STNil = STNil +d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b)) +d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b)) +d2 (STMaybe t) = STMaybe (d2 t) +d2 (STArr n t) = STArr n (d2 t) +d2 (STScal t) = case t of + STI32 -> STNil + STI64 -> STNil + STF32 -> STScal STF32 + STF64 -> STScal STF64 + STBool -> STNil +d2 STAccum{} = error "Accumulators not allowed in input program" diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 7ffb14c..8728ec0 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,39 +1,44 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} module Interpreter ( interpret, interpret', Value, - NoAccum(..), - unAccum, ) where +import Control.Monad (foldM) import Data.Int (Int64) import Data.Proxy +import System.IO.Unsafe (unsafePerformIO) +import Array import AST +import CHAD.Types import Data -import Array -import Interpreter.Accum import Interpreter.Rep -import Control.Monad (foldM) -interpret :: NoAccum t => Ex '[] t -> Rep t -interpret e = runAcM (go e) - where - go :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t) - go e' | Refl <- noAccum (Proxy @s) (Proxy @t) = interpret' SNil e' +newtype AcM s a = AcM (IO a) + deriving newtype (Functor, Applicative, Monad) -newtype Value s t = Value (Rep' s t) +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m -interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t) +interpret :: Ex '[] t -> Rep t +interpret e = runAcM (interpret' SNil e) + +newtype Value t = Value (Rep t) + +interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t) interpret' env = \case EVar _ _ i -> case slistIdx env i of Value x -> return x ELet _ a b -> do @@ -48,14 +53,17 @@ interpret' env = \case ECase _ e a b -> interpret' env e >>= \case Left x -> interpret' (Value x `SCons` env) a Right y -> interpret' (Value y `SCons` env) b + ENothing _ _ -> _ + EJust _ _ -> _ + EMaybe _ _ _ _ -> _ EConstArr _ _ _ v -> return v EBuild1 _ a b -> do n <- fromIntegral @Int64 @Int <$> interpret' env a arrayGenerateLinM (ShNil `ShCons` n) (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b) EBuild _ dim a b -> do - sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a - arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b) + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a + arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b) EFold1Inner _ a b -> do let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a arr <- interpret' env b @@ -75,9 +83,9 @@ interpret' env = \case EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) - EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b) - EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e - EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e + EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) + EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e + EOp _ op e -> interpretOp op <$> interpret' env e EWith e1 e2 -> do initval <- interpret' env e1 withAccum (typeOf e1) initval $ \accum -> @@ -87,10 +95,16 @@ interpret' env = \case val <- interpret' env e2 accum <- interpret' env e3 accumAdd accum i idx val + EZero t -> do + return $ makeZero t + EPlus t a b -> do + a' <- interpret' env a + b' <- interpret' env b + return $ makePlus t a' b' EError _ s -> error $ "Interpreter: Program threw error: " ++ s -interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t -interpretOp _ op arg = case op of +interpretOp :: SOp a t -> Rep a -> Rep t +interpretOp op arg = case op of OAdd st -> numericIsNum st $ uncurry (+) arg OMul st -> numericIsNum st $ uncurry (*) arg ONeg st -> numericIsNum st $ negate arg @@ -100,23 +114,66 @@ interpretOp _ op arg = case op of ONot -> not arg OIf -> if arg then Left () else Right () +makeZero :: STy t -> Rep (D2 t) +makeZero typ = case typ of + STNil -> () + STPair _ _ -> Left () + STEither _ _ -> Left () + STMaybe _ -> Nothing + STArr n _ -> emptyArray n + STScal sty -> case sty of + STI32 -> () + STI64 -> () + STF32 -> 0.0 + STF64 -> 0.0 + STBool -> () + STAccum{} -> error "Zero of Accum" + +makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) +makePlus typ a b = case typ of + STNil -> () + STPair t1 t2 -> case (a, b) of + (Left (), _) -> b + (_, Left ()) -> a + (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2) + STEither t1 t2 -> case (a, b) of + (Left (), _) -> b + (_, Left ()) -> a + (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y)) + (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y)) + _ -> error "Plus of inconsistent Eithers" + STArr _ t -> + let sh1 = arrayShape a + sh2 = arrayShape b + in if | shapeSize sh1 == 0 -> b + | shapeSize sh2 == 0 -> a + | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i)) + | otherwise -> error "Plus of inconsistently shaped arrays" + STScal sty -> case sty of + STI32 -> () + STI64 -> () + STF32 -> a + b + STF64 -> a + b + STBool -> () + STAccum{} -> error "Plus of Accum" + numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r numericIsNum STI32 = id numericIsNum STI64 = id numericIsNum STF32 = id numericIsNum STF64 = id -unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m)) - -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n -unTupRepIdx _ nil _ SZ _ = nil -unTupRepIdx p nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i +unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) + -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n +unTupRepIdx nil _ SZ _ = nil +unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i -tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int)) - -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx)) -tupRepIdx _ _ SZ _ = () -tupRepIdx p uncons (SS n) tup = +tupRepIdx :: (forall m. f (S m) -> (f m, Int)) + -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) +tupRepIdx _ SZ _ = () +tupRepIdx uncons (SS n) tup = let (tup', i) = uncons tup - in (tupRepIdx p uncons n tup', fromIntegral @Int @Int64 i) + in (tupRepIdx uncons n tup', fromIntegral @Int @Int64 i) ixUncons :: Index (S n) -> (Index n, Int) ixUncons (IxCons idx i) = (idx, i) @@ -124,33 +181,6 @@ ixUncons (IxCons idx i) = (idx, i) shUncons :: Shape (S n) -> (Shape n, Int) shUncons (ShCons idx i) = (idx, i) -class NoAccum t where - noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t -instance NoAccum TNil where - noAccum _ _ = Refl -instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where - noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl -instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where - noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl -instance NoAccum t => NoAccum (TArr n t) where - noAccum p _ | Refl <- noAccum p (Proxy @t) = Refl -instance NoAccum (TScal t) where - noAccum _ _ = Refl - -unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t)) -unAccum _ STNil = Just Dict -unAccum p (STPair t1 t2) - | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict - | otherwise = Nothing -unAccum p (STEither t1 t2) - | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict - | otherwise = Nothing -unAccum p (STArr _ t) - | Just Dict <- unAccum p t = Just Dict - | otherwise = Nothing -unAccum _ STScal{} = Just Dict -unAccum _ STAccum{} = Nothing - foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a foldl1M _ [] = error "foldl1M: empty list" foldl1M f (tophead : toptail) = foldM f tophead toptail diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs index b6a91df..af7be1e 100644 --- a/src/Interpreter/Accum.hs +++ b/src/Interpreter/Accum.hs @@ -51,6 +51,7 @@ type family Rep' s t where Rep' s TNil = () Rep' s (TPair a b) = (Rep' s a, Rep' s b) Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) + Rep' s (TMaybe t) = Maybe (Rep' s t) Rep' s (TArr n t) = Array n (Rep' s t) Rep' s (TScal sty) = ScalRep sty Rep' s (TAccum t) = Accum s t @@ -61,16 +62,13 @@ data Accum s t = Accum (STy t) (ForeignPtr ()) tSize :: Proxy s -> STy t -> Rep' s t -> Int tSize p ty x = tSize' p ty (Just x) --- | Passing Nothing as the value means "this is (inside) an array element". -tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int -tSize' p typ val = case typ of +tSize' :: Proxy s -> STy t -> Int +tSize' p typ = case typ of STNil -> 0 - STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val) - STEither a b -> - case val of - Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing) - Just (Left x) -> 1 + tSize' p a (Just x) -- '1 +' is for runtime sanity checking - Just (Right y) -> 1 + tSize' p b (Just y) -- idem + STPair a b -> tSize' p a + tSize' p b + STEither a b -> 1 + max (tSize' p a) (tSize' p b) + -- Representation of Maybe t is the same as Either () t; the add operation is different, however. + STMaybe t -> tSize' p (STEither STNil t) STArr ndim t -> case val of Nothing -> error "Nested arrays not supported in this implementation" @@ -99,7 +97,7 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> go inarr b (snd val) off1 STEither a b -> do let !(I# off#) = off - case val of + off1 <- case val of Left x -> do let !(I8# tag#) = 0 writeInt8# addr# off# tag# @@ -108,6 +106,11 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> let !(I8# tag#) = 1 writeInt8# addr# off# tag# go inarr b y (off + 1) + if inarr + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)) + else return off1 + -- Representation is the same, but add operation is different + STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off STArr _ t | inarr -> error "Nested arrays not supported in this implementation" | otherwise -> do @@ -158,6 +161,8 @@ accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> if inarr then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) else return (off1, val) + -- Representation is the same, but add operation is different + STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off STArr ndim t | inarr -> error "Nested arrays not supported in this implementation" | otherwise -> do @@ -219,6 +224,7 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" + (STMaybe t, _, _) -> _ idx val (STArr rank eltty, _, _) | inarr -> error "Nested arrays" | otherwise -> do diff --git a/src/Interpreter/AccumOld.hs b/src/Interpreter/AccumOld.hs new file mode 100644 index 0000000..af7be1e --- /dev/null +++ b/src/Interpreter/AccumOld.hs @@ -0,0 +1,366 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +module Interpreter.Accum ( + AcM, + runAcM, + Rep', + Accum, + withAccum, + accumAdd, + inParallel, +) where + +import Control.Concurrent +import Control.Monad (when, forM_) +import Data.Bifunctor (second) +import Data.Proxy +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) +import Foreign.Storable (sizeOf) +import GHC.Exts +import GHC.Float +import GHC.Int +import GHC.IO (IO(..)) +import GHC.Word +import System.IO.Unsafe (unsafePerformIO) + +import Array +import AST +import Data + + +newtype AcM s a = AcM (IO a) + deriving newtype (Functor, Applicative, Monad) + +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m + +-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. +type family Rep' s t where + Rep' s TNil = () + Rep' s (TPair a b) = (Rep' s a, Rep' s b) + Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) + Rep' s (TMaybe t) = Maybe (Rep' s t) + Rep' s (TArr n t) = Array n (Rep' s t) + Rep' s (TScal sty) = ScalRep sty + Rep' s (TAccum t) = Accum s t + +-- | Floats and integers are accumulated; booleans are left as-is. +data Accum s t = Accum (STy t) (ForeignPtr ()) + +tSize :: Proxy s -> STy t -> Rep' s t -> Int +tSize p ty x = tSize' p ty (Just x) + +tSize' :: Proxy s -> STy t -> Int +tSize' p typ = case typ of + STNil -> 0 + STPair a b -> tSize' p a + tSize' p b + STEither a b -> 1 + max (tSize' p a) (tSize' p b) + -- Representation of Maybe t is the same as Either () t; the add operation is different, however. + STMaybe t -> tSize' p (STEither STNil t) + STArr ndim t -> + case val of + Nothing -> error "Nested arrays not supported in this implementation" + Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing + STScal sty -> goScal sty + STAccum{} -> error "Nested accumulators unsupported" + where + goScal :: SScalTy t -> Int + goScal STI32 = 4 + goScal STI64 = 8 + goScal STF32 = 4 + goScal STF64 = 8 + goScal STBool = 1 + +-- | This operation does not commute with 'accumAdd', so it must be used with +-- care. Furthermore it must be used on exactly the same value as tSize was +-- called on. Hence it lives in IO, not in AcM. +accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () +accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int + go inarr ty val off = case ty of + STNil -> return off + STPair a b -> do + off1 <- go inarr a (fst val) off + go inarr b (snd val) off1 + STEither a b -> do + let !(I# off#) = off + off1 <- case val of + Left x -> do + let !(I8# tag#) = 0 + writeInt8# addr# off# tag# + go inarr a x (off + 1) + Right y -> do + let !(I8# tag#) = 1 + writeInt8# addr# off# tag# + go inarr b y (off + 1) + if inarr + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)) + else return off1 + -- Representation is the same, but add operation is different + STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off + STArr _ t + | inarr -> error "Nested arrays not supported in this implementation" + | otherwise -> do + off1 <- goShape (arrayShape val) off + let eltsize = tSize' (Proxy @s) t Nothing + n = arraySize val + traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val + return (off1 + eltsize * n) + STScal sty -> goScal sty val off + STAccum{} -> error "Nested accumulators unsupported" + + goShape :: Shape n -> Int -> IO Int + goShape ShNil off = return off + goShape (ShCons sh n) off = do + off1@(I# off1#) <- goShape sh off + let !(I64# n'#) = fromIntegral n + writeInt64# addr# off1# n'# + return (off1 + 8) + + goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int + goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x + goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x + goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x + goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x + goScal STBool b off@(I# off#) = do + let !(I8# i) = fromIntegral (fromEnum b) + off + 1 <$ writeInt8# addr# off# i + + in () <$ go False topty top_value 0 + +accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) +accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') + go inarr ty off = case ty of + STNil -> return (off, ()) + STPair a b -> do + (off1, x) <- go inarr a off + (off2, y) <- go inarr b off1 + return (off1 + off2, (x, y)) + STEither a b -> do + let !(I# off#) = off + tag <- readInt8 addr# off# + (off1, val) <- case tag of + 0 -> fmap Left <$> go inarr a (off + 1) + 1 -> fmap Right <$> go inarr b (off + 1) + _ -> error "Invalid tag in accum memory" + if inarr + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) + else return (off1, val) + -- Representation is the same, but add operation is different + STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off + STArr ndim t + | inarr -> error "Nested arrays not supported in this implementation" + | otherwise -> do + (off1, sh) <- readShape addr# ndim off + let eltsize = tSize' (Proxy @s) t Nothing + n = shapeSize sh + arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) + return (off1 + eltsize * n, arr) + STScal sty -> goScal sty off + STAccum{} -> error "Nested accumulators unsupported" + + goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t') + goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off# + goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off# + goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off# + goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off# + goScal STBool off@(I# off#) = do + i8 <- readInt8 addr# off# + return (off + 1, toEnum (fromIntegral i8)) + + in snd <$> go False topty 0 + +readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n) +readShape _ SZ off = return (off, ShNil) +readShape mbarr (SS ndim) off = do + (off1@(I# off1#), sh) <- readShape mbarr ndim off + n' <- readInt64 mbarr off1# + return (off1 + 8, ShCons sh (fromIntegral n')) + +-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of +-- the list. +data InvShape n where + IShNil :: InvShape Z + IShCons :: Int -- ^ How many subarrays are there? + -> Int -- ^ What is the size of all subarrays together? + -> InvShape n -- ^ Sub array inverted shape + -> InvShape (S n) + +ishSize :: InvShape n -> Int +ishSize IShNil = 1 +ishSize (IShCons _ sz _) = sz + +invertShape :: forall n. Shape n -> InvShape n +invertShape | Refl <- lemPlusZero @n = flip go IShNil + where + go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m) + go ShNil ish = ish + go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) + +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () +accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () + go inarr ty SZ () val off = () <$ performAdd inarr ty val off + go inarr ty (SS dep) idx val off = case (ty, idx, val) of + (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off + (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off + (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd" + (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off + (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off + (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" + (STMaybe t, _, _) -> _ idx val + (STArr rank eltty, _, _) + | inarr -> error "Nested arrays" + | otherwise -> do + (off1, ish) <- second invertShape <$> readShape addr# rank off + goArr (SS dep) ish eltty idx val off1 + (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth" + (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth" + (STAccum{}, _, _) -> error "Nested accumulators unsupported" + + goArr :: SNat i' -> InvShape n -> STy t' + -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () + goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off + goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off + goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do + let i' = fromIntegral @(Rep' s TIx) @Int i + when (i' < 0 || i' >= n) $ + error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" + goArr depm1 ish t1 idx val (off + i' * ishSize ish) + + performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int + performAddArr arraySz eltty val off = do + let eltsize = tSize' (Proxy @s) eltty Nothing + forM_ [0 .. arraySz - 1] $ \lini -> + performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) + return (off + arraySz * eltsize) + + performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int + performAdd inarr ty val off = case ty of + STNil -> return off + STPair t1 t2 -> do + off1 <- performAdd inarr t1 (fst val) off + performAdd inarr t2 (snd val) off1 + STEither t1 t2 -> do + let !(I# off#) = off + tag <- readInt8 addr# off# + off1 <- case (val, tag) of + (Left val1, 0) -> performAdd inarr t1 val1 (off + 1) + (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) + _ -> error "accumAdd: Tag mismatch for Either" + if inarr + then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) + else return off1 + STArr n ty' + | inarr -> error "Nested array" + | otherwise -> do + (off1, sh) <- readShape addr# n off + performAddArr (shapeSize sh) ty' val off1 + STScal ty' -> performAddScal ty' val off + STAccum{} -> error "Nested accumulators unsupported" + + performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int + performAddScal STI32 (I32# x#) off@(I# off#) + | sizeOf (undefined :: Int) == 4 + = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#)) + | otherwise + = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#)) + performAddScal STI64 (I64# x#) off@(I# off#) + | sizeOf (undefined :: Int) == 8 + = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#)) + | otherwise + = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#)) + performAddScal STF32 x off@(I# off#) = + off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w)) + performAddScal STF64 x off@(I# off#) = + off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w)) + performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans + + casLoop :: Eq w + => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#) + -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed) + -> Addr# -- ^ Address to attempt to modify + -> (w -> w) -- ^ Operation to apply to the value + -> IO () + casLoop readOp casOp addr modify = readOp addr 0# >>= loop + where + loop value = do + value' <- casOp addr value (modify value) + if value == value' + then return () + else loop value' + + in () <$ go False topty top_depth top_index top_value 0 + +withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) +withAccum ty start fun = do + -- The initial write must happen before any of the adds or reads, so it makes + -- sense to put it in IO together with the allocation, instead of in AcM. + accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) + ptr <- newForeignPtr finalizerFree buffer + let accum = Accum ty ptr + accumWrite accum start + return accum + b <- fun accum + out <- accumRead accum + return (b, out) + +inParallel :: [AcM s t] -> AcM s [t] +inParallel actions = AcM $ do + mvars <- mapM (\_ -> newEmptyMVar) actions + forM_ (zip actions mvars) $ \(AcM action, var) -> + forkIO $ action >>= putMVar var + mapM takeMVar mvars + +-- | Offset is in bytes. +readInt8 :: Addr# -> Int# -> IO Int8 +readInt32 :: Addr# -> Int# -> IO Int32 +readInt64 :: Addr# -> Int# -> IO Int64 +readWord32 :: Addr# -> Int# -> IO Word32 +readWord64 :: Addr# -> Int# -> IO Word64 +readFloat :: Addr# -> Int# -> IO Float +readDouble :: Addr# -> Int# -> IO Double +readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #) +readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #) +readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #) +readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #) +readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #) +readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #) +readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #) + +writeInt8# :: Addr# -> Int# -> Int8# -> IO () +writeInt32# :: Addr# -> Int# -> Int32# -> IO () +writeInt64# :: Addr# -> Int# -> Int64# -> IO () +writeFloat# :: Addr# -> Int# -> Float# -> IO () +writeDouble# :: Addr# -> Int# -> Double# -> IO () +writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #) + +fetchAddWord# :: Addr# -> Int# -> Word# -> IO () +fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #) + +atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32 +atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64 +atomicCasWord32Addr addr (W32# expected) (W32# desired) = + IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #) +atomicCasWord64Addr addr (W64# expected) (W64# desired) = + IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #) diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 1ded773..7add442 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -2,6 +2,8 @@ {-# LANGUAGE TypeFamilies #-} module Interpreter.Rep where +import Data.IORef +import qualified Data.Vector.Mutable as MV import GHC.TypeError import Array @@ -12,6 +14,22 @@ type family Rep t where Rep TNil = () Rep (TPair a b) = (Rep a, Rep b) Rep (TEither a b) = Either (Rep a) (Rep b) + Rep (TMaybe t) = Maybe (Rep t) Rep (TArr n t) = Array n (Rep t) Rep (TScal sty) = ScalRep sty - Rep (TAccum t) = TypeError (Text "Accumulator in Rep") + Rep (TAccum t) = IORef (RepAc t) + +type family RepAc t where + RepAc TNil = () + RepAc (TPair a b) = (RepAc a, RepAc b) + -- This is annoying when working with values of type 'RepAc t', because + -- failing a pattern match does not generate negative type information. + -- However, it works, saves us from having to defining a LEither type + -- first-class in the type system with + -- Rep (LEither a b) = Maybe (Either a b) + -- and it's not even incorrect, in a way. + RepAc (TMaybe (TEither a b)) = IORef (Maybe (Either (RepAc a) (RepAc b))) + RepAc (TMaybe t) = IORef (Maybe (RepAc t)) + RepAc (TArr n t) = (Shape n, MV.IOVector (RepAc t)) + RepAc (TScal sty) = IORef (ScalRep sty) + RepAc (TAccum t) = TypeError (Text "Nested accumulators") diff --git a/src/Simplify.hs b/src/Simplify.hs index 1640729..3ac68ed 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -68,6 +68,8 @@ simplify' = \case -- TODO: constant folding for operations + -- TODO: accum of zero, plus of zero + EVar _ t i -> EVar ext t i ELet _ a b -> ELet ext (simplify' a) (simplify' b) EPair _ a b -> EPair ext (simplify' a) (simplify' b) @@ -92,6 +94,8 @@ simplify' = \case EOp _ op e -> EOp ext op (simplify' e) EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2) EAccum i e1 e2 e3 -> EAccum i (simplify' e1) (simplify' e2) (simplify' e3) + EZero t -> EZero t + EPlus t a b -> EPlus t (simplify' a) (simplify' b) EError t s -> EError t s cheapExpr :: Expr x env t -> Bool @@ -129,6 +133,8 @@ hasAdds = \case EOp _ _ e -> hasAdds e EWith a b -> hasAdds a || hasAdds b EAccum _ _ _ _ -> True + EZero _ -> False + EPlus _ a b -> hasAdds a || hasAdds b EError _ _ -> False checkAccumInScope :: SList STy env -> Bool |