diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 109 | ||||
| -rw-r--r-- | src/AST/Count.hs | 3 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 7 | ||||
| -rw-r--r-- | src/AST/Types.hs | 95 | ||||
| -rw-r--r-- | src/Array.hs | 15 | ||||
| -rw-r--r-- | src/CHAD.hs | 136 | ||||
| -rw-r--r-- | src/CHAD/Types.hs | 65 | ||||
| -rw-r--r-- | src/Interpreter.hs | 142 | ||||
| -rw-r--r-- | src/Interpreter/Accum.hs | 26 | ||||
| -rw-r--r-- | src/Interpreter/AccumOld.hs | 366 | ||||
| -rw-r--r-- | src/Interpreter/Rep.hs | 20 | ||||
| -rw-r--r-- | src/Simplify.hs | 6 | 
12 files changed, 753 insertions, 237 deletions
| @@ -13,92 +13,19 @@  {-# LANGUAGE StandaloneKindSignatures #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-} -module AST (module AST, module AST.Weaken) where +module AST (module AST, module AST.Types, module AST.Weaken) where  import Data.Functor.Const  import Data.Kind (Type) -import Data.Int -import Data.Type.Equality  import Array  import AST.Env +import AST.Types  import AST.Weaken +import CHAD.Types  import Data -data Ty -  = TNil -  | TPair Ty Ty -  | TEither Ty Ty -  | TArr Nat Ty  -- ^ rank, element type -  | TScal ScalTy -  | TAccum Ty -  deriving (Show, Eq, Ord) - -data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool -  deriving (Show, Eq, Ord) - -type STy :: Ty -> Type -data STy t where -  STNil :: STy TNil -  STPair :: STy a -> STy b -> STy (TPair a b) -  STEither :: STy a -> STy b -> STy (TEither a b) -  STArr :: SNat n -> STy t -> STy (TArr n t) -  STScal :: SScalTy t -> STy (TScal t) -  STAccum :: STy t -> STy (TAccum t) -deriving instance Show (STy t) - -instance TestEquality STy where -  testEquality STNil STNil = Just Refl -  testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl -  testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl -  testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl -  testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl -  testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl -  testEquality _ _ = Nothing - -data SScalTy t where -  STI32 :: SScalTy TI32 -  STI64 :: SScalTy TI64 -  STF32 :: SScalTy TF32 -  STF64 :: SScalTy TF64 -  STBool :: SScalTy TBool -deriving instance Show (SScalTy t) - -instance TestEquality SScalTy where -  testEquality STI32 STI32 = Just Refl -  testEquality STI64 STI64 = Just Refl -  testEquality STF32 STF32 = Just Refl -  testEquality STF64 STF64 = Just Refl -  testEquality STBool STBool = Just Refl -  testEquality _ _ = Nothing - -scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) -scalRepIsShow STI32 = Dict -scalRepIsShow STI64 = Dict -scalRepIsShow STF32 = Dict -scalRepIsShow STF64 = Dict -scalRepIsShow STBool = Dict - -type TIx = TScal TI64 - -tIx :: STy TIx -tIx = STScal STI64 - -type family ScalRep t where -  ScalRep TI32 = Int32 -  ScalRep TI64 = Int64 -  ScalRep TF32 = Float -  ScalRep TF64 = Double -  ScalRep TBool = Bool - -type family ScalIsNumeric t where -  ScalIsNumeric TI32 = True -  ScalIsNumeric TI64 = True -  ScalIsNumeric TF32 = True -  ScalIsNumeric TF64 = True -  ScalIsNumeric TBool = False -  -- | This index is flipped around from the usual direction: the smallest index  -- is at the _heart_ of the nesting, not at the outside. The outermost layer  -- indexes into the _outer_ dimension of the type @t@. This makes indices into @@ -107,6 +34,7 @@ type family AcIdx t i where    AcIdx t Z = TNil    AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i)    AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i) +  AcIdx (TMaybe t) (S i) = AcIdx t i    AcIdx (TArr Z t) (S i) = AcIdx t i    AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i) @@ -114,12 +42,20 @@ type family AcVal t i where    AcVal t Z = t    AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i)    AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i) +  AcVal (TMaybe t) (S i) = AcVal t i    AcVal (TArr Z t) (S i) = AcVal t i    AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i  -- General assumption: head of the list (whatever way it is associated) is the  -- inner variable / inner array dimension. In pretty printing, the inner  -- variable / inner dimension is printed on the _right_. +-- +-- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the +-- type transformation of CHAD. Indeed, these constructors are created _by_ +-- CHAD, and are intended to be eliminated after simplification, so that the +-- input program as well as the output program do not contain these +-- constructors. +-- TODO: ensure this by a "stage" type parameter.  type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type  data Expr x env t where    -- lambda calculus @@ -134,6 +70,9 @@ data Expr x env t where    EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)    EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)    ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c +  ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) +  EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) +  EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b    -- array operations    EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) @@ -157,6 +96,10 @@ data Expr x env t where    EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil    -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil +  -- monoidal operations (to be desugared to regular operations after simplification) +  EZero :: STy t -> Expr x env (D2 t) +  EPlus :: STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) +    -- partiality    EError :: STy a -> String -> Expr x env a  deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) @@ -220,6 +163,9 @@ typeOf = \case    EInl _ t2 e -> STEither (typeOf e) t2    EInr _ t1 e -> STEither t1 (typeOf e)    ECase _ _ a _ -> typeOf a +  ENothing _ t -> STMaybe t +  EJust _ e -> STMaybe (typeOf e) +  EMaybe _ e _ _ -> typeOf e    EConstArr _ n t _ -> STArr n (STScal t)    EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) @@ -239,6 +185,9 @@ typeOf = \case    EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)    EAccum _ _ _ _ -> STNil +  EZero t -> d2 t +  EPlus t _ _ -> d2 t +    EError t _ -> t  unSNat :: SNat n -> Nat @@ -250,6 +199,7 @@ unSTy = \case    STNil -> TNil    STPair a b -> TPair (unSTy a) (unSTy b)    STEither a b -> TEither (unSTy a) (unSTy b) +  STMaybe t -> TMaybe (unSTy t)    STArr n t -> TArr (unSNat n) (unSTy t)    STScal t -> TScal (unSScalTy t)    STAccum t -> TAccum (unSTy t) @@ -288,6 +238,9 @@ subst' f w = \case    EInl x t e -> EInl x t (subst' f w e)    EInr x t e -> EInr x t (subst' f w e)    ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) +  ENothing x t -> ENothing x t +  EJust x e -> EJust x (subst' f w e) +  EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)    EConstArr x n t a -> EConstArr x n t a    EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)    EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) @@ -303,6 +256,8 @@ subst' f w = \case    EOp x op e -> EOp x op (subst' f w e)    EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)    EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3) +  EZero t -> EZero t +  EPlus t a b -> EPlus t (subst' f w a) (subst' f w b)    EError t s -> EError t s    where      sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -339,6 +294,7 @@ class KnownTy t where knownTy :: STy t  instance KnownTy TNil where knownTy = STNil  instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy  instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy  instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy  instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy  instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy @@ -351,6 +307,7 @@ styKnown :: STy t -> Dict (KnownTy t)  styKnown STNil = Dict  styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict  styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STMaybe t) | Dict <- styKnown t = Dict  styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict  styKnown (STScal t) | Dict <- sscaltyKnown t = Dict  styKnown (STAccum t) | Dict <- styKnown t = Dict diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 6a00e83..40a46f6 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -46,6 +46,7 @@ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)  -- | This code is executed many times  scaleMany :: Occ -> Occ +scaleMany (Occ l Zero) = Occ l Zero  scaleMany (Occ l _) = Occ l Many  occCount :: Idx env a -> Expr x env t -> Occ @@ -124,6 +125,8 @@ occCountGeneral onehot unpush alter many = go WId        EOp _ _ e -> re e        EWith a b -> re a <> re1 b        EAccum _ a b e -> re a <> re b <> re e +      EZero _ -> mempty +      EPlus _ a b -> re a <> re b        EError{} -> mempty        where          re :: Monoid (r env') => Expr x env' t'' -> r env' diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 2ce883b..f5e681a 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -181,6 +181,13 @@ ppExpr' d val = \case      return $ showParen (d > 10) $        showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' +  EZero _ -> return $ showString "zero" + +  EPlus _ a b -> do +    a' <- ppExpr' 11 val a +    b' <- ppExpr' 11 val b +    return $ showParen (d > 10) $ showString "plus " . a' . showString " " . b' +    EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)  ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS diff --git a/src/AST/Types.hs b/src/AST/Types.hs new file mode 100644 index 0000000..a3e5080 --- /dev/null +++ b/src/AST/Types.hs @@ -0,0 +1,95 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +module AST.Types where + +import Data.Int (Int32, Int64) +import Data.Kind (Type) +import Data.Type.Equality + +import Data + + +data Ty +  = TNil +  | TPair Ty Ty +  | TEither Ty Ty +  | TMaybe Ty +  | TArr Nat Ty  -- ^ rank, element type +  | TScal ScalTy +  | TAccum Ty +  deriving (Show, Eq, Ord) + +data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool +  deriving (Show, Eq, Ord) + +type STy :: Ty -> Type +data STy t where +  STNil :: STy TNil +  STPair :: STy a -> STy b -> STy (TPair a b) +  STEither :: STy a -> STy b -> STy (TEither a b) +  STMaybe :: STy a -> STy (TMaybe a) +  STArr :: SNat n -> STy t -> STy (TArr n t) +  STScal :: SScalTy t -> STy (TScal t) +  STAccum :: STy t -> STy (TAccum t) +deriving instance Show (STy t) + +instance TestEquality STy where +  testEquality STNil STNil = Just Refl +  testEquality STNil _ = Nothing +  testEquality (STPair a b) (STPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl +  testEquality STPair{} _ = Nothing +  testEquality (STEither a b) (STEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl +  testEquality STEither{} _ = Nothing +  testEquality (STMaybe a) (STMaybe a') | Just Refl <- testEquality a a' = Just Refl +  testEquality STMaybe{} _ = Nothing +  testEquality (STArr a b) (STArr a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl +  testEquality STArr{} _ = Nothing +  testEquality (STScal a) (STScal a') | Just Refl <- testEquality a a' = Just Refl +  testEquality STScal{} _ = Nothing +  testEquality (STAccum a) (STAccum a') | Just Refl <- testEquality a a' = Just Refl +  testEquality STAccum{} _ = Nothing + +data SScalTy t where +  STI32 :: SScalTy TI32 +  STI64 :: SScalTy TI64 +  STF32 :: SScalTy TF32 +  STF64 :: SScalTy TF64 +  STBool :: SScalTy TBool +deriving instance Show (SScalTy t) + +instance TestEquality SScalTy where +  testEquality STI32 STI32 = Just Refl +  testEquality STI64 STI64 = Just Refl +  testEquality STF32 STF32 = Just Refl +  testEquality STF64 STF64 = Just Refl +  testEquality STBool STBool = Just Refl +  testEquality _ _ = Nothing + +scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) +scalRepIsShow STI32 = Dict +scalRepIsShow STI64 = Dict +scalRepIsShow STF32 = Dict +scalRepIsShow STF64 = Dict +scalRepIsShow STBool = Dict + +type TIx = TScal TI64 + +tIx :: STy TIx +tIx = STScal STI64 + +type family ScalRep t where +  ScalRep TI32 = Int32 +  ScalRep TI64 = Int64 +  ScalRep TF32 = Float +  ScalRep TF64 = Double +  ScalRep TBool = Bool + +type family ScalIsNumeric t where +  ScalIsNumeric TI32 = True +  ScalIsNumeric TI64 = True +  ScalIsNumeric TF32 = True +  ScalIsNumeric TF64 = True +  ScalIsNumeric TBool = False diff --git a/src/Array.hs b/src/Array.hs index 9a770c4..0d585a9 100644 --- a/src/Array.hs +++ b/src/Array.hs @@ -17,11 +17,13 @@ data Shape n where    ShNil :: Shape Z    ShCons :: Shape n -> Int -> Shape (S n)  deriving instance Show (Shape n) +deriving instance Eq (Shape n)  data Index n where    IxNil :: Index Z    IxCons :: Index n -> Int -> Index (S n)  deriving instance Show (Index n) +deriving instance Eq (Index n)  shapeSize :: Shape n -> Int  shapeSize ShNil = 0 @@ -38,6 +40,10 @@ toLinearIndex :: Shape n -> Index n -> Int  toLinearIndex ShNil IxNil = 0  toLinearIndex (sh `ShCons` n) (idx `IxCons` i) = toLinearIndex sh idx * n + i +emptyShape :: SNat n -> Shape n +emptyShape SZ = ShNil +emptyShape (SS m) = emptyShape m `ShCons` 0 +  -- | TODO: this Vector is a boxed vector, which is horrendously inefficient.  data Array (n :: Nat) t = Array (Shape n) (Vector t) @@ -49,6 +55,9 @@ arrayShape (Array sh _) = sh  arraySize :: Array n t -> Int  arraySize (Array sh _) = shapeSize sh +emptyArray :: SNat n -> Array n t +emptyArray n = Array (emptyShape n) V.empty +  arrayIndex :: Array n t -> Index n -> t  arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx) @@ -58,6 +67,12 @@ arrayIndexLinear (Array _ v) i = v V.! i  arrayIndex1 :: Array (S n) t -> Int -> Array n t  arrayIndex1 (Array (sh `ShCons` _) v) i = let sz = shapeSize sh in Array sh (V.slice (sz * i) sz v) +arrayGenerate :: Shape n -> (Index n -> t) -> Array n t +arrayGenerate sh f = arrayGenerateLin sh (f . fromLinearIndex sh) + +arrayGenerateLin :: Shape n -> (Int -> t) -> Array n t +arrayGenerateLin sh f = Array sh (V.generate (shapeSize sh) f) +  arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t)  arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh) diff --git a/src/CHAD.hs b/src/CHAD.hs index 7747d46..1ab2da0 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -37,6 +37,7 @@ import AST  import AST.Count  import AST.Env  import AST.Weaken.Auto +import CHAD.Types  import Data  import Lemmas @@ -288,66 +289,12 @@ vectorise1Binds env n (bs `BPush` (t, e)) =                         (vectoriseExpr SNil (bindingsBinds bs) env e)    in bs' `BPush` (STArr (SS SZ) t, e') -type family D1 t where -  D1 TNil = TNil -  D1 (TPair a b) = TPair (D1 a) (D1 b) -  D1 (TEither a b) = TEither (D1 a) (D1 b) -  D1 (TArr n t) = TArr n (D1 t) -  D1 (TScal t) = TScal t - -type family D2 t where -  D2 TNil = TNil -  D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b)) -  D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b)) -  D2 (TArr n t) = TArr n (D2 t) -  D2 (TScal t) = D2s t - -type family D2s t where -  D2s TI32 = TNil -  D2s TI64 = TNil -  D2s TF32 = TScal TF32 -  D2s TF64 = TScal TF64 -  D2s TBool = TNil - -type family D1E env where -  D1E '[] = '[] -  D1E (t : env) = D1 t : D1E env - -type family D2E env where -  D2E '[] = '[] -  D2E (t : env) = D2 t : D2E env - -type family D2AcE env where -  D2AcE '[] = '[] -  D2AcE (t : env) = TAccum (D2 t) : D2AcE env -  -- | Select only the types from the environment that have the specified storage  type family Select env sto s where    Select '[] '[] _ = '[]    Select (t : ts) (s : sto) s = t : Select ts sto s    Select (_ : ts) (_ : sto) s = Select ts sto s -d1 :: STy t -> STy (D1 t) -d1 STNil = STNil -d1 (STPair a b) = STPair (d1 a) (d1 b) -d1 (STEither a b) = STEither (d1 a) (d1 b) -d1 (STArr n t) = STArr n (d1 t) -d1 (STScal t) = STScal t -d1 STAccum{} = error "Accumulators not allowed in input program" - -d2 :: STy t -> STy (D2 t) -d2 STNil = STNil -d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b)) -d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b)) -d2 (STArr n t) = STArr n (d2 t) -d2 (STScal t) = case t of -  STI32 -> STNil -  STI64 -> STNil -  STF32 -> STScal STF32 -  STF64 -> STScal STF64 -  STBool -> STNil -d2 STAccum{} = error "Accumulators not allowed in input program" -  conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)  conv1Idx IZ = IZ  conv1Idx (IS i) = IS (conv1Idx i) @@ -362,48 +309,49 @@ conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i)  conv2Idx DTop i = case i of {}  zero :: STy t -> Ex env (D2 t) -zero STNil = ENil ext -zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext) -zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) -zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) -zero (STScal t) = case t of -  STI32 -> ENil ext -  STI64 -> ENil ext -  STF32 -> EConst ext STF32 0.0 -  STF64 -> EConst ext STF64 0.0 -  STBool -> ENil ext -zero STAccum{} = error "Accumulators not allowed in input program" +zero = EZero +-- TODO: this original definition needs to be used as the post-processing after +-- simplification, to eliminate the monoid operations from the AST +-- zero STNil = ENil ext +-- zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext) +-- zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) +-- zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) +-- zero (STScal t) = case t of +--   STI32 -> ENil ext +--   STI64 -> ENil ext +--   STF32 -> EConst ext STF32 0.0 +--   STF64 -> EConst ext STF64 0.0 +--   STBool -> ENil ext +-- zero STAccum{} = error "Accumulators not allowed in input program"  plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) -plus STNil _ _ = ENil ext -plus (STPair t1 t2) a b = -  let t = STPair (d2 t1) (d2 t2) -  in plusSparse t a b $ -       EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) -                          (EFst ext (EVar ext t IZ))) -                 (plus t2 (ESnd ext (EVar ext t (IS IZ))) -                          (ESnd ext (EVar ext t IZ))) -plus (STEither t1 t2) a b = -  let t = STEither (d2 t1) (d2 t2) -  in plusSparse t a b $ -       ECase ext (EVar ext t (IS IZ)) -         (ECase ext (EVar ext t (IS IZ)) -           (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) -           (EError t "plus l+r")) -         (ECase ext (EVar ext t (IS IZ)) -           (EError t "plus r+l") -           (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) -plus STArr{} _ _ = error "TODO plus on arrays" -    -- 'zero' creates an empty array; this should be a new primitive that -    -- (operationally) intelligently memcpy's the non-overlapping part and does -    -- a parallel add on the overlapping part. -plus (STScal t) a b = case t of -  STI32 -> ENil ext -  STI64 -> ENil ext -  STF32 -> EOp ext (OAdd STF32) (EPair ext a b) -  STF64 -> EOp ext (OAdd STF64) (EPair ext a b) -  STBool -> ENil ext -plus STAccum{} _ _ = error "Accumulators not allowed in input program" +plus = EPlus +-- plus STNil _ _ = ENil ext +-- plus (STPair t1 t2) a b = +--   let t = STPair (d2 t1) (d2 t2) +--   in plusSparse t a b $ +--        EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) +--                           (EFst ext (EVar ext t IZ))) +--                  (plus t2 (ESnd ext (EVar ext t (IS IZ))) +--                           (ESnd ext (EVar ext t IZ))) +-- plus (STEither t1 t2) a b = +--   let t = STEither (d2 t1) (d2 t2) +--   in plusSparse t a b $ +--        ECase ext (EVar ext t (IS IZ)) +--          (ECase ext (EVar ext t (IS IZ)) +--            (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) +--            (EError t "plus l+r")) +--          (ECase ext (EVar ext t (IS IZ)) +--            (EError t "plus r+l") +--            (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) +-- plus STArr{} _ _ = error "TODO plus on arrays" +-- plus (STScal t) a b = case t of +--   STI32 -> ENil ext +--   STI64 -> ENil ext +--   STF32 -> EOp ext (OAdd STF32) (EPair ext a b) +--   STF64 -> EOp ext (OAdd STF64) (EPair ext a b) +--   STBool -> ENil ext +-- plus STAccum{} _ _ = error "Accumulators not allowed in input program"  plusSparse :: STy a             -> Ex env (TEither TNil a) -> Ex env (TEither TNil a) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs new file mode 100644 index 0000000..0b32393 --- /dev/null +++ b/src/CHAD/Types.hs @@ -0,0 +1,65 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Types where + +import AST.Types + + +type family D1 t where +  D1 TNil = TNil +  D1 (TPair a b) = TPair (D1 a) (D1 b) +  D1 (TEither a b) = TEither (D1 a) (D1 b) +  D1 (TMaybe a) = TMaybe (D1 a) +  D1 (TArr n t) = TArr n (D1 t) +  D1 (TScal t) = TScal t + +type family D2 t where +  D2 TNil = TNil +  D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b)) +  D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b)) +  D2 (TMaybe t) = TMaybe (D2 t) +  D2 (TArr n t) = TArr n (D2 t) +  D2 (TScal t) = D2s t + +type family D2s t where +  D2s TI32 = TNil +  D2s TI64 = TNil +  D2s TF32 = TScal TF32 +  D2s TF64 = TScal TF64 +  D2s TBool = TNil + +type family D1E env where +  D1E '[] = '[] +  D1E (t : env) = D1 t : D1E env + +type family D2E env where +  D2E '[] = '[] +  D2E (t : env) = D2 t : D2E env + +type family D2AcE env where +  D2AcE '[] = '[] +  D2AcE (t : env) = TAccum (D2 t) : D2AcE env + +d1 :: STy t -> STy (D1 t) +d1 STNil = STNil +d1 (STPair a b) = STPair (d1 a) (d1 b) +d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STMaybe t) = STMaybe (d1 t) +d1 (STArr n t) = STArr n (d1 t) +d1 (STScal t) = STScal t +d1 STAccum{} = error "Accumulators not allowed in input program" + +d2 :: STy t -> STy (D2 t) +d2 STNil = STNil +d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b)) +d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b)) +d2 (STMaybe t) = STMaybe (d2 t) +d2 (STArr n t) = STArr n (d2 t) +d2 (STScal t) = case t of +  STI32 -> STNil +  STI64 -> STNil +  STF32 -> STScal STF32 +  STF64 -> STScal STF64 +  STBool -> STNil +d2 STAccum{} = error "Accumulators not allowed in input program" diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 7ffb14c..8728ec0 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,39 +1,44 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TypeApplications #-}  {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-}  module Interpreter (    interpret,    interpret',    Value, -  NoAccum(..), -  unAccum,  ) where +import Control.Monad (foldM)  import Data.Int (Int64)  import Data.Proxy +import System.IO.Unsafe (unsafePerformIO) +import Array  import AST +import CHAD.Types  import Data -import Array -import Interpreter.Accum  import Interpreter.Rep -import Control.Monad (foldM) -interpret :: NoAccum t => Ex '[] t -> Rep t -interpret e = runAcM (go e) -  where -    go :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t) -    go e' | Refl <- noAccum (Proxy @s) (Proxy @t) = interpret' SNil e' +newtype AcM s a = AcM (IO a) +  deriving newtype (Functor, Applicative, Monad) -newtype Value s t = Value (Rep' s t) +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m -interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t) +interpret :: Ex '[] t -> Rep t +interpret e = runAcM (interpret' SNil e) + +newtype Value t = Value (Rep t) + +interpret' :: forall env t s. SList Value env -> Ex env t -> AcM s (Rep t)  interpret' env = \case    EVar _ _ i -> case slistIdx env i of Value x -> return x    ELet _ a b -> do @@ -48,14 +53,17 @@ interpret' env = \case    ECase _ e a b -> interpret' env e >>= \case                       Left x -> interpret' (Value x `SCons` env) a                       Right y -> interpret' (Value y `SCons` env) b +  ENothing _ _ -> _ +  EJust _ _ -> _ +  EMaybe _ _ _ _ -> _    EConstArr _ _ _ v -> return v    EBuild1 _ a b -> do      n <- fromIntegral @Int64 @Int <$> interpret' env a      arrayGenerateLinM (ShNil `ShCons` n)                        (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)    EBuild _ dim a b -> do -    sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a -    arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b) +    sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a +    arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)    EFold1Inner _ a b -> do      let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a      arr <- interpret' env b @@ -75,9 +83,9 @@ interpret' env = \case    EConst _ _ v -> return v    EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e    EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) -  EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b) -  EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e -  EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e +  EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) +  EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e +  EOp _ op e -> interpretOp op <$> interpret' env e    EWith e1 e2 -> do      initval <- interpret' env e1      withAccum (typeOf e1) initval $ \accum -> @@ -87,10 +95,16 @@ interpret' env = \case      val <- interpret' env e2      accum <- interpret' env e3      accumAdd accum i idx val +  EZero t -> do +    return $ makeZero t +  EPlus t a b -> do +    a' <- interpret' env a +    b' <- interpret' env b +    return $ makePlus t a' b'    EError _ s -> error $ "Interpreter: Program threw error: " ++ s -interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t -interpretOp _ op arg = case op of +interpretOp :: SOp a t -> Rep a -> Rep t +interpretOp op arg = case op of    OAdd st -> numericIsNum st $ uncurry (+) arg    OMul st -> numericIsNum st $ uncurry (*) arg    ONeg st -> numericIsNum st $ negate arg @@ -100,23 +114,66 @@ interpretOp _ op arg = case op of    ONot -> not arg    OIf -> if arg then Left () else Right () +makeZero :: STy t -> Rep (D2 t) +makeZero typ = case typ of +  STNil -> () +  STPair _ _ -> Left () +  STEither _ _ -> Left () +  STMaybe _ -> Nothing +  STArr n _ -> emptyArray n +  STScal sty -> case sty of +                  STI32 -> () +                  STI64 -> () +                  STF32 -> 0.0 +                  STF64 -> 0.0 +                  STBool -> () +  STAccum{} -> error "Zero of Accum" + +makePlus :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) +makePlus typ a b = case typ of +  STNil -> () +  STPair t1 t2 -> case (a, b) of +    (Left (), _) -> b +    (_, Left ()) -> a +    (Right (x1, x2), Right (y1, y2)) -> Right (makePlus t1 x1 y1, makePlus t2 x2 y2) +  STEither t1 t2 -> case (a, b) of +    (Left (), _) -> b +    (_, Left ()) -> a +    (Right (Left x), Right (Left y)) -> Right (Left (makePlus t1 x y)) +    (Right (Right x), Right (Right y)) -> Right (Right (makePlus t2 x y)) +    _ -> error "Plus of inconsistent Eithers" +  STArr _ t -> +    let sh1 = arrayShape a +        sh2 = arrayShape b +    in if | shapeSize sh1 == 0 -> b +          | shapeSize sh2 == 0 -> a +          | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> makePlus t (arrayIndexLinear a i) (arrayIndexLinear b i)) +          | otherwise -> error "Plus of inconsistently shaped arrays" +  STScal sty -> case sty of +    STI32 -> () +    STI64 -> () +    STF32 -> a + b +    STF64 -> a + b +    STBool -> () +  STAccum{} -> error "Plus of Accum" +  numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r  numericIsNum STI32 = id  numericIsNum STI64 = id  numericIsNum STF32 = id  numericIsNum STF64 = id -unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m)) -            -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n -unTupRepIdx _ nil _    SZ _ = nil -unTupRepIdx p nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i +unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) +            -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n +unTupRepIdx nil _    SZ _ = nil +unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i -tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int)) -          -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx)) -tupRepIdx _ _      SZ _ = () -tupRepIdx p uncons (SS n) tup = +tupRepIdx :: (forall m. f (S m) -> (f m, Int)) +          -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) +tupRepIdx _      SZ _ = () +tupRepIdx uncons (SS n) tup =    let (tup', i) = uncons tup -  in (tupRepIdx p uncons n tup', fromIntegral @Int @Int64 i) +  in (tupRepIdx uncons n tup', fromIntegral @Int @Int64 i)  ixUncons :: Index (S n) -> (Index n, Int)  ixUncons (IxCons idx i) = (idx, i) @@ -124,33 +181,6 @@ ixUncons (IxCons idx i) = (idx, i)  shUncons :: Shape (S n) -> (Shape n, Int)  shUncons (ShCons idx i) = (idx, i) -class NoAccum t where -  noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t -instance NoAccum TNil where -  noAccum _ _ = Refl -instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where -  noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl -instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where -  noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl -instance NoAccum t => NoAccum (TArr n t) where -  noAccum p _ | Refl <- noAccum p (Proxy @t) = Refl -instance NoAccum (TScal t) where -  noAccum _ _ = Refl - -unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t)) -unAccum _ STNil = Just Dict -unAccum p (STPair t1 t2) -  | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict -  | otherwise = Nothing -unAccum p (STEither t1 t2) -  | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict -  | otherwise = Nothing -unAccum p (STArr _ t) -  | Just Dict <- unAccum p t = Just Dict -  | otherwise = Nothing -unAccum _ STScal{} = Just Dict -unAccum _ STAccum{} = Nothing -  foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a  foldl1M _ [] = error "foldl1M: empty list"  foldl1M f (tophead : toptail) = foldM f tophead toptail diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs index b6a91df..af7be1e 100644 --- a/src/Interpreter/Accum.hs +++ b/src/Interpreter/Accum.hs @@ -51,6 +51,7 @@ type family Rep' s t where    Rep' s TNil = ()    Rep' s (TPair a b) = (Rep' s a, Rep' s b)    Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) +  Rep' s (TMaybe t) = Maybe (Rep' s t)    Rep' s (TArr n t) = Array n (Rep' s t)    Rep' s (TScal sty) = ScalRep sty    Rep' s (TAccum t) = Accum s t @@ -61,16 +62,13 @@ data Accum s t = Accum (STy t) (ForeignPtr ())  tSize :: Proxy s -> STy t -> Rep' s t -> Int  tSize p ty x = tSize' p ty (Just x) --- | Passing Nothing as the value means "this is (inside) an array element". -tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int -tSize' p typ val = case typ of +tSize' :: Proxy s -> STy t -> Int +tSize' p typ = case typ of    STNil -> 0 -  STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val) -  STEither a b -> -    case val of -      Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing) -      Just (Left x) -> 1 + tSize' p a (Just x)   -- '1 +' is for runtime sanity checking -      Just (Right y) -> 1 + tSize' p b (Just y)  -- idem +  STPair a b -> tSize' p a + tSize' p b +  STEither a b -> 1 + max (tSize' p a) (tSize' p b) +  -- Representation of Maybe t is the same as Either () t; the add operation is different, however. +  STMaybe t -> tSize' p (STEither STNil t)    STArr ndim t ->      case val of        Nothing -> error "Nested arrays not supported in this implementation" @@ -99,7 +97,7 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->          go inarr b (snd val) off1        STEither a b -> do          let !(I# off#) = off -        case val of +        off1 <- case val of            Left x -> do              let !(I8# tag#) = 0              writeInt8# addr# off# tag# @@ -108,6 +106,11 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->              let !(I8# tag#) = 1              writeInt8# addr# off# tag#              go inarr b y (off + 1) +        if inarr +          then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)) +          else return off1 +      -- Representation is the same, but add operation is different +      STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off        STArr _ t          | inarr -> error "Nested arrays not supported in this implementation"          | otherwise -> do @@ -158,6 +161,8 @@ accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->          if inarr            then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)            else return (off1, val) +      -- Representation is the same, but add operation is different +      STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t)  off        STArr ndim t          | inarr -> error "Nested arrays not supported in this implementation"          | otherwise -> do @@ -219,6 +224,7 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr        (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off        (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off        (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" +      (STMaybe t, _, _) -> _ idx val        (STArr rank eltty, _, _)          | inarr -> error "Nested arrays"          | otherwise -> do diff --git a/src/Interpreter/AccumOld.hs b/src/Interpreter/AccumOld.hs new file mode 100644 index 0000000..af7be1e --- /dev/null +++ b/src/Interpreter/AccumOld.hs @@ -0,0 +1,366 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +module Interpreter.Accum ( +  AcM, +  runAcM, +  Rep', +  Accum, +  withAccum, +  accumAdd, +  inParallel, +) where + +import Control.Concurrent +import Control.Monad (when, forM_) +import Data.Bifunctor (second) +import Data.Proxy +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) +import Foreign.Storable (sizeOf) +import GHC.Exts +import GHC.Float +import GHC.Int +import GHC.IO (IO(..)) +import GHC.Word +import System.IO.Unsafe (unsafePerformIO) + +import Array +import AST +import Data + + +newtype AcM s a = AcM (IO a) +  deriving newtype (Functor, Applicative, Monad) + +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m + +-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. +type family Rep' s t where +  Rep' s TNil = () +  Rep' s (TPair a b) = (Rep' s a, Rep' s b) +  Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) +  Rep' s (TMaybe t) = Maybe (Rep' s t) +  Rep' s (TArr n t) = Array n (Rep' s t) +  Rep' s (TScal sty) = ScalRep sty +  Rep' s (TAccum t) = Accum s t + +-- | Floats and integers are accumulated; booleans are left as-is. +data Accum s t = Accum (STy t) (ForeignPtr ()) + +tSize :: Proxy s -> STy t -> Rep' s t -> Int +tSize p ty x = tSize' p ty (Just x) + +tSize' :: Proxy s -> STy t -> Int +tSize' p typ = case typ of +  STNil -> 0 +  STPair a b -> tSize' p a + tSize' p b +  STEither a b -> 1 + max (tSize' p a) (tSize' p b) +  -- Representation of Maybe t is the same as Either () t; the add operation is different, however. +  STMaybe t -> tSize' p (STEither STNil t) +  STArr ndim t -> +    case val of +      Nothing -> error "Nested arrays not supported in this implementation" +      Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing +  STScal sty -> goScal sty +  STAccum{} -> error "Nested accumulators unsupported" +  where +    goScal :: SScalTy t -> Int +    goScal STI32 = 4 +    goScal STI64 = 8 +    goScal STF32 = 4 +    goScal STF64 = 8 +    goScal STBool = 1 + +-- | This operation does not commute with 'accumAdd', so it must be used with +-- care. Furthermore it must be used on exactly the same value as tSize was +-- called on. Hence it lives in IO, not in AcM. +accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () +accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> +  let +    go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int +    go inarr ty val off = case ty of +      STNil -> return off +      STPair a b -> do +        off1 <- go inarr a (fst val) off +        go inarr b (snd val) off1 +      STEither a b -> do +        let !(I# off#) = off +        off1 <- case val of +          Left x -> do +            let !(I8# tag#) = 0 +            writeInt8# addr# off# tag# +            go inarr a x (off + 1) +          Right y -> do +            let !(I8# tag#) = 1 +            writeInt8# addr# off# tag# +            go inarr b y (off + 1) +        if inarr +          then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)) +          else return off1 +      -- Representation is the same, but add operation is different +      STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off +      STArr _ t +        | inarr -> error "Nested arrays not supported in this implementation" +        | otherwise -> do +            off1 <- goShape (arrayShape val) off +            let eltsize = tSize' (Proxy @s) t Nothing +                n = arraySize val +            traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val +            return (off1 + eltsize * n) +      STScal sty -> goScal sty val off +      STAccum{} -> error "Nested accumulators unsupported" + +    goShape :: Shape n -> Int -> IO Int +    goShape ShNil off = return off +    goShape (ShCons sh n) off = do +      off1@(I# off1#) <- goShape sh off +      let !(I64# n'#) = fromIntegral n +      writeInt64# addr# off1# n'# +      return (off1 + 8) + +    goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int +    goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x +    goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x +    goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x +    goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x +    goScal STBool b off@(I# off#) = do +      let !(I8# i) = fromIntegral (fromEnum b) +      off + 1 <$ writeInt8# addr# off# i + +  in () <$ go False topty top_value 0 + +accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) +accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> +  let +    go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') +    go inarr ty off = case ty of +      STNil -> return (off, ()) +      STPair a b -> do +        (off1, x) <- go inarr a off +        (off2, y) <- go inarr b off1 +        return (off1 + off2, (x, y)) +      STEither a b -> do +        let !(I# off#) = off +        tag <- readInt8 addr# off# +        (off1, val) <- case tag of +                         0 -> fmap Left <$> go inarr a (off + 1) +                         1 -> fmap Right <$> go inarr b (off + 1) +                         _ -> error "Invalid tag in accum memory" +        if inarr +          then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) +          else return (off1, val) +      -- Representation is the same, but add operation is different +      STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t)  off +      STArr ndim t +        | inarr -> error "Nested arrays not supported in this implementation" +        | otherwise -> do +            (off1, sh) <- readShape addr# ndim off +            let eltsize = tSize' (Proxy @s) t Nothing +                n = shapeSize sh +            arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) +            return (off1 + eltsize * n, arr) +      STScal sty -> goScal sty off +      STAccum{} -> error "Nested accumulators unsupported" + +    goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t') +    goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off# +    goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off# +    goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off# +    goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off# +    goScal STBool off@(I# off#) = do +      i8 <- readInt8 addr# off# +      return (off + 1, toEnum (fromIntegral i8)) + +  in snd <$> go False topty 0 + +readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n) +readShape _ SZ off = return (off, ShNil) +readShape mbarr (SS ndim) off = do +  (off1@(I# off1#), sh) <- readShape mbarr ndim off +  n' <- readInt64 mbarr off1# +  return (off1 + 8, ShCons sh (fromIntegral n')) + +-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of +-- the list. +data InvShape n where +  IShNil :: InvShape Z +  IShCons :: Int  -- ^ How many subarrays are there? +          -> Int  -- ^ What is the size of all subarrays together? +          -> InvShape n  -- ^ Sub array inverted shape +          -> InvShape (S n) + +ishSize :: InvShape n -> Int +ishSize IShNil = 1 +ishSize (IShCons _ sz _) = sz + +invertShape :: forall n. Shape n -> InvShape n +invertShape | Refl <- lemPlusZero @n = flip go IShNil +  where +    go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m) +    go ShNil ish = ish +    go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) + +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () +accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> +  let +    go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () +    go inarr ty SZ () val off = () <$ performAdd inarr ty val off +    go inarr ty (SS dep) idx val off = case (ty, idx, val) of +      (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off +      (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off +      (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd" +      (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off +      (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off +      (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" +      (STMaybe t, _, _) -> _ idx val +      (STArr rank eltty, _, _) +        | inarr -> error "Nested arrays" +        | otherwise -> do +            (off1, ish) <- second invertShape <$> readShape addr# rank off +            goArr (SS dep) ish eltty idx val off1 +      (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth" +      (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth" +      (STAccum{}, _, _) -> error "Nested accumulators unsupported" + +    goArr :: SNat i' -> InvShape n -> STy t' +          -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () +    goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off +    goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off +    goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do +      let i' = fromIntegral @(Rep' s TIx) @Int i +      when (i' < 0 || i' >= n) $ +        error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" +      goArr depm1 ish t1 idx val (off + i' * ishSize ish) + +    performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int +    performAddArr arraySz eltty val off = do +      let eltsize = tSize' (Proxy @s) eltty Nothing +      forM_ [0 .. arraySz - 1] $ \lini -> +        performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) +      return (off + arraySz * eltsize) + +    performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int +    performAdd inarr ty val off = case ty of +      STNil -> return off +      STPair t1 t2 -> do +        off1 <- performAdd inarr t1 (fst val) off +        performAdd inarr t2 (snd val) off1 +      STEither t1 t2 -> do +        let !(I# off#) = off +        tag <- readInt8 addr# off# +        off1 <- case (val, tag) of +                  (Left val1, 0) -> performAdd inarr t1 val1 (off + 1) +                  (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) +                  _ -> error "accumAdd: Tag mismatch for Either" +        if inarr +          then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) +          else return off1 +      STArr n ty' +        | inarr -> error "Nested array" +        | otherwise -> do +            (off1, sh) <- readShape addr# n off +            performAddArr (shapeSize sh) ty' val off1 +      STScal ty' -> performAddScal ty' val off +      STAccum{} -> error "Nested accumulators unsupported" + +    performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int +    performAddScal STI32 (I32# x#) off@(I# off#) +      | sizeOf (undefined :: Int) == 4 +      = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#)) +      | otherwise +      = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#)) +    performAddScal STI64 (I64# x#) off@(I# off#) +      | sizeOf (undefined :: Int) == 8 +      = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#)) +      | otherwise +      = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#)) +    performAddScal STF32 x off@(I# off#) = +      off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w)) +    performAddScal STF64 x off@(I# off#) = +      off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w)) +    performAddScal STBool _ off = return (off + 1)  -- don't do anything with booleans + +    casLoop :: Eq w +            => (Addr# -> Int# -> IO w)    -- ^ read value (from a given byte offset; will get 0#) +            -> (Addr# -> w -> w -> IO w)  -- ^ CAS value at address (expected -> desired -> IO observed) +            -> Addr#                      -- ^ Address to attempt to modify +            -> (w -> w)                   -- ^ Operation to apply to the value +            -> IO () +    casLoop readOp casOp addr modify = readOp addr 0# >>= loop +      where +        loop value = do +          value' <- casOp addr value (modify value) +          if value == value' +            then return () +            else loop value' + +  in () <$ go False topty top_depth top_index top_value 0 + +withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) +withAccum ty start fun = do +  -- The initial write must happen before any of the adds or reads, so it makes +  -- sense to put it in IO together with the allocation, instead of in AcM. +  accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) +                    ptr <- newForeignPtr finalizerFree buffer +                    let accum = Accum ty ptr +                    accumWrite accum start +                    return accum +  b <- fun accum +  out <- accumRead accum +  return (b, out) + +inParallel :: [AcM s t] -> AcM s [t] +inParallel actions = AcM $ do +  mvars <- mapM (\_ -> newEmptyMVar) actions +  forM_ (zip actions mvars) $ \(AcM action, var) -> +    forkIO $ action >>= putMVar var +  mapM takeMVar mvars + +-- | Offset is in bytes. +readInt8   :: Addr# -> Int# -> IO Int8 +readInt32  :: Addr# -> Int# -> IO Int32 +readInt64  :: Addr# -> Int# -> IO Int64 +readWord32 :: Addr# -> Int# -> IO Word32 +readWord64 :: Addr# -> Int# -> IO Word64 +readFloat  :: Addr# -> Int# -> IO Float +readDouble :: Addr# -> Int# -> IO Double +readInt8   addr off# = IO $ \s -> case readInt8OffAddr#   (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8#  val #) +readInt32  addr off# = IO $ \s -> case readInt32OffAddr#  (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #) +readInt64  addr off# = IO $ \s -> case readInt64OffAddr#  (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #) +readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #) +readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #) +readFloat  addr off# = IO $ \s -> case readFloatOffAddr#  (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F#   val #) +readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D#   val #) + +writeInt8#   :: Addr# -> Int# -> Int8#   -> IO () +writeInt32#  :: Addr# -> Int# -> Int32#  -> IO () +writeInt64#  :: Addr# -> Int# -> Int64#  -> IO () +writeFloat#  :: Addr# -> Int# -> Float#  -> IO () +writeDouble# :: Addr# -> Int# -> Double# -> IO () +writeInt8#   addr off# val = IO $ \s -> (# writeInt8OffAddr#   (addr `plusAddr#` off#) 0# val s, () #) +writeInt32#  addr off# val = IO $ \s -> (# writeInt32OffAddr#  (addr `plusAddr#` off#) 0# val s, () #) +writeInt64#  addr off# val = IO $ \s -> (# writeInt64OffAddr#  (addr `plusAddr#` off#) 0# val s, () #) +writeFloat#  addr off# val = IO $ \s -> (# writeFloatOffAddr#  (addr `plusAddr#` off#) 0# val s, () #) +writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #) + +fetchAddWord# :: Addr# -> Int# -> Word# -> IO () +fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #) + +atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32 +atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64 +atomicCasWord32Addr addr (W32# expected) (W32# desired) = +  IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #) +atomicCasWord64Addr addr (W64# expected) (W64# desired) = +  IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #) diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 1ded773..7add442 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -2,6 +2,8 @@  {-# LANGUAGE TypeFamilies #-}  module Interpreter.Rep where +import Data.IORef +import qualified Data.Vector.Mutable as MV  import GHC.TypeError  import Array @@ -12,6 +14,22 @@ type family Rep t where    Rep TNil = ()    Rep (TPair a b) = (Rep a, Rep b)    Rep (TEither a b) = Either (Rep a) (Rep b) +  Rep (TMaybe t) = Maybe (Rep t)    Rep (TArr n t) = Array n (Rep t)    Rep (TScal sty) = ScalRep sty -  Rep (TAccum t) = TypeError (Text "Accumulator in Rep") +  Rep (TAccum t) = IORef (RepAc t) + +type family RepAc t where +  RepAc TNil = () +  RepAc (TPair a b) = (RepAc a, RepAc b) +  -- This is annoying when working with values of type 'RepAc t', because +  -- failing a pattern match does not generate negative type information. +  -- However, it works, saves us from having to defining a LEither type +  -- first-class in the type system with +  --   Rep (LEither a b) = Maybe (Either a b) +  -- and it's not even incorrect, in a way. +  RepAc (TMaybe (TEither a b)) = IORef (Maybe (Either (RepAc a) (RepAc b))) +  RepAc (TMaybe t) = IORef (Maybe (RepAc t)) +  RepAc (TArr n t) = (Shape n, MV.IOVector (RepAc t)) +  RepAc (TScal sty) = IORef (ScalRep sty) +  RepAc (TAccum t) = TypeError (Text "Nested accumulators") diff --git a/src/Simplify.hs b/src/Simplify.hs index 1640729..3ac68ed 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -68,6 +68,8 @@ simplify' = \case    -- TODO: constant folding for operations +  -- TODO: accum of zero, plus of zero +    EVar _ t i -> EVar ext t i    ELet _ a b -> ELet ext (simplify' a) (simplify' b)    EPair _ a b -> EPair ext (simplify' a) (simplify' b) @@ -92,6 +94,8 @@ simplify' = \case    EOp _ op e -> EOp ext op (simplify' e)    EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2)    EAccum i e1 e2 e3 -> EAccum i (simplify' e1) (simplify' e2) (simplify' e3) +  EZero t -> EZero t +  EPlus t a b -> EPlus t (simplify' a) (simplify' b)    EError t s -> EError t s  cheapExpr :: Expr x env t -> Bool @@ -129,6 +133,8 @@ hasAdds = \case    EOp _ _ e -> hasAdds e    EWith a b -> hasAdds a || hasAdds b    EAccum _ _ _ _ -> True +  EZero _ -> False +  EPlus _ a b -> hasAdds a || hasAdds b    EError _ _ -> False  checkAccumInScope :: SList STy env -> Bool | 
