summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2023-09-10 21:13:14 +0200
committerTom Smeding <t.j.smeding@uu.nl>2023-09-10 21:13:14 +0200
commit0bf9f5bb8a0873cad2e11faf83519b6e7ccf87d2 (patch)
tree41ddd52b0293319834b2130814414e76de434396
Initial
-rw-r--r--.gitignore1
-rw-r--r--LICENSE24
-rw-r--r--chad-fast.cabal28
-rw-r--r--prelude.cu3
-rw-r--r--src/AST.hs211
-rw-r--r--src/CHAD.hs121
-rw-r--r--src/Compile.hs120
-rw-r--r--src/PreludeCu.hs9
8 files changed, 517 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c33954f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+dist-newstyle/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..253d664
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,24 @@
+Copyright (c) 2023 Tom Smeding.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+3. Neither the name of the copyright holder nor the names of its contributors
+ may be used to endorse or promote products derived from this software
+ without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/chad-fast.cabal b/chad-fast.cabal
new file mode 100644
index 0000000..35fa7f8
--- /dev/null
+++ b/chad-fast.cabal
@@ -0,0 +1,28 @@
+cabal-version: 2.2
+name: chad-fast
+synopsis: Fast CHAD
+version: 0.1.0.0
+license: BSD-3-Clause
+license-file: LICENSE
+author: Tom Smeding
+maintainer: tom@tomsmeding.com
+build-type: Simple
+
+library
+ exposed-modules:
+ AST
+ CHAD
+ -- Compile
+ PreludeCu
+ other-modules:
+ build-depends:
+ base >= 4.14 && < 4.19,
+ containers,
+ template-haskell,
+ some
+ hs-source-dirs:
+ src
+ default-language:
+ Haskell2010
+ ghc-options:
+ -Wall -threaded
diff --git a/prelude.cu b/prelude.cu
new file mode 100644
index 0000000..e63ccf8
--- /dev/null
+++ b/prelude.cu
@@ -0,0 +1,3 @@
+#include <utility>
+
+struct Nil {};
diff --git a/src/AST.hs b/src/AST.hs
new file mode 100644
index 0000000..4d642ba
--- /dev/null
+++ b/src/AST.hs
@@ -0,0 +1,211 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE DeriveFoldable #-}
+module AST where
+
+import Data.Kind (Type)
+import Data.Int
+
+
+data Nat = Z | S Nat
+ deriving (Show, Eq, Ord)
+
+data SNat n where
+ SZ :: SNat Z
+ SS :: SNat n -> SNat (S n)
+deriving instance (Show (SNat n))
+
+data Vec n t where
+ VNil :: Vec n t
+ (:<) :: t -> Vec n t -> Vec (S n) t
+deriving instance Show t => Show (Vec n t)
+deriving instance Functor (Vec n)
+deriving instance Foldable (Vec n)
+
+data SList f l where
+ SNil :: SList f '[]
+ SCons :: f a -> SList f l -> SList f (a : l)
+deriving instance (forall a. Show (f a)) => Show (SList f l)
+
+data Ty
+ = TNil
+ | TPair Ty Ty
+ | TArr Nat Ty -- ^ rank, element type
+ | TScal ScalTy
+ | TEVM [Ty] Ty
+ deriving (Show, Eq, Ord)
+
+data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
+ deriving (Show, Eq, Ord)
+
+type STy :: Ty -> Type
+data STy t where
+ STNil :: STy TNil
+ STPair :: STy a -> STy b -> STy (TPair a b)
+ STArr :: SNat n -> STy t -> STy (TArr n t)
+ STScal :: SScalTy t -> STy (TScal t)
+ STEVM :: SList STy env -> STy t -> STy (TEVM env t)
+deriving instance Show (STy t)
+
+data SScalTy t where
+ STI32 :: SScalTy TI32
+ STI64 :: SScalTy TI64
+ STF32 :: SScalTy TF32
+ STF64 :: SScalTy TF64
+ STBool :: SScalTy TBool
+deriving instance Show (SScalTy t)
+
+type TIx = TScal TI64
+
+type Idx :: [Ty] -> Ty -> Type
+data Idx env t where
+ IZ :: Idx (t : env) t
+ IS :: Idx env t -> Idx (a : env) t
+deriving instance Show (Idx env t)
+
+type family ScalRep t where
+ ScalRep TI32 = Int32
+ ScalRep TI64 = Int64
+ ScalRep TF32 = Float
+ ScalRep TF64 = Double
+ ScalRep TBool = Bool
+
+type ConsN :: Nat -> a -> [a] -> [a]
+type family ConsN n x l where
+ ConsN Z x l = l
+ ConsN (S n) x l = x : ConsN n x l
+
+type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
+data Expr x env t where
+ -- lambda calculus
+ EVar :: x t -> STy t -> Idx env t -> Expr x env t
+ ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t
+
+ -- array operations
+ EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
+ EBuild :: x (TArr n t) -> SNat n -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t)
+ EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
+
+ -- expression operations
+ EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
+ EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
+ EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t
+ EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
+
+ -- EVM operations
+ EMOne :: Idx venv t -> Expr x env t -> Expr x env (TEVM venv TNil)
+deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
+
+type SOp :: Ty -> Ty -> Type
+data SOp a t where
+ OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ OMul :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ ONeg :: SScalTy a -> SOp (TScal a) (TScal a)
+ OLt :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OLe :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ ONot :: SOp (TScal TBool) (TScal TBool)
+deriving instance Show (SOp a t)
+
+opt2 :: SOp a t -> STy t
+opt2 = \case
+ OAdd t -> STScal t
+ OMul t -> STScal t
+ ONeg t -> STScal t
+ OLt _ -> STScal STBool
+ OLe _ -> STScal STBool
+ OEq _ -> STScal STBool
+ ONot -> STScal STBool
+
+typeOf :: Expr x env t -> STy t
+typeOf = \case
+ EVar _ t _ -> t
+ ELet _ _ e -> typeOf e
+ EBuild1 _ _ e -> STArr (SS SZ) (typeOf e)
+ EBuild _ n _ e -> STArr n (typeOf e)
+ EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
+
+ -- expression operations
+ EConst _ t _ -> STScal t
+ EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
+ EIdx _ e _ | STArr _ t <- typeOf e -> t
+ EOp _ op _ -> opt2 op
+
+ EMOne _ _ -> STEVM _ STNil
+
+unSNat :: SNat n -> Nat
+unSNat SZ = Z
+unSNat (SS n) = S (unSNat n)
+
+unSTy :: STy t -> Ty
+unSTy = \case
+ STNil -> TNil
+ STPair a b -> TPair (unSTy a) (unSTy b)
+ STArr n t -> TArr (unSNat n) (unSTy t)
+ STScal t -> TScal (unSScalTy t)
+ STEVM l t -> TEVM (unSList l) (unSTy t)
+
+unSList :: SList STy env -> [Ty]
+unSList SNil = []
+unSList (SCons t l) = unSTy t : unSList l
+
+unSScalTy :: SScalTy t -> ScalTy
+unSScalTy = \case
+ STI32 -> TI32
+ STI64 -> TI64
+ STF32 -> TF32
+ STF64 -> TF64
+ STBool -> TBool
+
+fromNat :: Nat -> Int
+fromNat Z = 0
+fromNat (S n) = succ (fromNat n)
+
+data env :> env' where
+ WId :: env :> env
+ WSink :: env :> (t : env)
+ WCopy :: env :> env' -> (t : env) :> (t : env')
+ WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3
+deriving instance Show (env :> env')
+
+(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3
+(.>) = flip WThen
+
+infixr @>
+(@>) :: env :> env' -> Idx env t -> Idx env' t
+WId @> i = i
+WSink @> i = IS i
+WCopy _ @> IZ = IZ
+WCopy w @> (IS i) = IS (w @> i)
+WThen w1 w2 @> i = w2 @> w1 @> i
+
+weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
+weakenExpr w = \case
+ EVar x t i -> EVar x t (w @> i)
+ ELet x rhs body -> ELet x (weakenExpr w rhs) (weakenExpr (WCopy w) body)
+ EBuild1 x e1 e2 -> EBuild1 x (weakenExpr w e1) (weakenExpr (WCopy w) e2)
+ EBuild x n es e -> EBuild x n (weakenVec w es) (weakenExpr (wcopyN n w) e)
+ EFold1 x e1 e2 -> EFold1 x (weakenExpr (WCopy (WCopy w)) e1) (weakenExpr w e2)
+ EConst x t v -> EConst x t v
+ EIdx1 x e1 e2 -> EIdx1 x (weakenExpr w e1) (weakenExpr w e2)
+ EIdx x e1 es -> EIdx x (weakenExpr w e1) (weakenVec w es)
+ EOp x op e -> EOp x op (weakenExpr w e)
+ EMOne i e -> EMOne i (weakenExpr w e)
+
+wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env'
+wcopyN SZ w = w
+wcopyN (SS n) w = WCopy (wcopyN n w)
+
+weakenVec :: (env :> env') -> Vec n (Expr x env TIx) -> Vec n (Expr x env' TIx)
+weakenVec _ VNil = VNil
+weakenVec w (e :< v) = weakenExpr w e :< weakenVec w v
diff --git a/src/CHAD.hs b/src/CHAD.hs
new file mode 100644
index 0000000..17ee12b
--- /dev/null
+++ b/src/CHAD.hs
@@ -0,0 +1,121 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module CHAD where
+
+import Data.Functor.Const
+
+import AST
+
+
+type Ex = Expr (Const ())
+
+data Bindings f env env' where
+ BTop :: Bindings f env env
+ BPush :: Bindings f env env' -> f env' t -> Bindings f env (t : env')
+deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')
+
+weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
+ -> env1 :> env2 -> Bindings f env1 env'
+ -> (forall env2'. Bindings f env2 env2' -> env' :> env2' -> r) -> r
+weakenBindings _ w BTop k = k BTop w
+weakenBindings wf w (BPush b x) k =
+ weakenBindings wf w b $ \b' w' -> k (BPush b' (wf w' x)) (WCopy w')
+
+sinkWithBindings :: Bindings f env env' -> env :> env'
+sinkWithBindings BTop = WId
+sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b
+
+bconcat :: Bindings f env1 env2 -> Bindings f env2 env3 -> Bindings f env1 env3
+bconcat b1 BTop = b1
+bconcat b1 (BPush b2 x) = BPush (bconcat b1 b2) x
+
+bconcat' :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
+ -> Bindings f env env1 -> Bindings f env env2 -> (forall env12. Bindings env env12 -> r) -> r
+bconcat' wf b1 b2 = weakenBindings
+-- bconcat :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
+-- -> Bindings f env env1 -> Bindings f env env2 -> env :> env'
+-- -> (forall env'12. Bindings f env' env'12 -> r) -> r
+-- bconcat wf BTop b w k = weakenBindings wf w b $ \b' _ -> k b'
+-- bconcat wf (BPush b x) b2 w k =
+-- bconcat wf
+
+type family D1 t where
+ D1 TNil = TNil
+ D1 (TPair a b) = TPair (D1 a) (D1 b)
+ D1 (TArr n t) = TArr n (D1 t)
+ D1 (TScal t) = TScal t
+
+type family D2 t where
+ D2 TNil = TNil
+ D2 (TPair a b) = TPair (D2 a) (D2 b)
+ -- D2 (TArr n t) = _
+ D2 (TScal t) = D2s t
+
+type family D2s t where
+ D2s TI32 = TNil
+ D2s TI64 = TNil
+ D2s TF32 = TScal TF32
+ D2s TF64 = TScal TF64
+ D2s TBool = TNil
+
+type family D1E env where
+ D1E '[] = '[]
+ D1E (t : env) = D1 t : D1E env
+
+type family D2E env where
+ D2E '[] = '[]
+ D2E (t : env) = D2 t : D2E env
+
+data Ret env t =
+ forall env'.
+ Ret (Bindings Ex (D1E env) env')
+ (Ex env' (D1 t))
+ (Ex (D2 t : env') (TEVM (D2E env) TNil))
+deriving instance Show (Ret env t)
+
+d1 :: STy t -> STy (D1 t)
+d1 STNil = STNil
+d1 (STPair a b) = STPair (d1 a) (d1 b)
+d1 (STArr n t) = STArr n (d1 t)
+d1 (STScal t) = STScal t
+d1 STEVM{} = error "EVM not allowed in input program"
+
+d2 :: STy t -> STy (D2 t)
+d2 STNil = STNil
+d2 (STPair a b) = STPair (d2 a) (d2 b)
+d2 STArr{} = error "TODO arrays"
+d2 (STScal t) = case t of
+ STI32 -> STNil
+ STI64 -> STNil
+ STF32 -> STScal STF32
+ STF64 -> STScal STF64
+ STBool -> STNil
+d2 STEVM{} = error "EVM not allowed in input program"
+
+conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
+conv1Idx IZ = IZ
+conv1Idx (IS i) = IS (conv1Idx i)
+
+conv2Idx :: Idx env t -> Idx (D2E env) (D2 t)
+conv2Idx IZ = IZ
+conv2Idx (IS i) = IS (conv2Idx i)
+
+drev :: Ex env t -> Ret env t
+drev = \case
+ EVar _ t i ->
+ Ret BTop
+ (EVar ext (d1 t) (conv1Idx i))
+ (EMOne (conv2Idx i) (EVar ext (d2 t) IZ))
+ ELet _ rhs body ->
+ Ret _
+ _
+ _
+ where
+ ext = Const ()
diff --git a/src/Compile.hs b/src/Compile.hs
new file mode 100644
index 0000000..2fcff5d
--- /dev/null
+++ b/src/Compile.hs
@@ -0,0 +1,120 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE DeriveFunctor #-}
+module Compile where
+
+import Control.Monad (ap)
+import Data.Kind (Type)
+import Data.Map.Strict (Map)
+import qualified Data.Map.Strict as Map
+
+import AST
+
+
+data Body = Body [Stm] Inline -- body, return expr
+ deriving (Show)
+
+data Stm
+ = VarDef String String (Maybe Inline) -- type, name, initialiser
+ | Launch Inline Inline Body -- num blocks, block size, kernel function body
+ deriving (Show)
+
+-- inline cuda expression
+data Inline
+ = IOp Inline String Inline
+ | IUOp String Inline
+ | ILit String
+ | IVar String
+ | ICall Inline [Inline]
+ deriving (Show)
+
+data Target = Host | Device
+ deriving (Show)
+
+data FunDef = FunDef Target String [String] Body -- name, params (full declarations), body
+ deriving (Show)
+
+type Env :: [Ty] -> Type -> Type
+data Env env v where
+ ETop :: Env '[] v
+ EPush :: v -> Env env v -> Env (t : env) v
+
+prj :: Env env v -> Idx env t -> v
+prj = \env idx -> go idx env
+ where go :: Idx env t -> Env env v -> v
+ go IZ (EPush v _) = v
+ go (IS i) (EPush _ env) = go i env
+
+-- generated global function definitions, generated local statements, function typedef cache (name, decl)
+newtype M a = M (Int -> Map Ty (String, String) -> ([FunDef], [Stm], Map Ty (String, String), Int, a))
+ deriving (Functor)
+instance Applicative M where
+ pure x = M (\i m -> ([], [], m, i, x))
+ (<*>) = ap
+instance Monad M where
+ M f >>= g = M (\i m -> let (d1, s1, m1, i1, x) = f i m
+ (d2, s2, m2, i2, y) = let M h = g x in h i1 m1
+ in (d1 <> d2, s1 <> s2, m2, i2, y))
+
+emitFun :: FunDef -> M ()
+emitFun fd = M (\i m -> ([fd], [], m, i, ()))
+
+emitStm :: Stm -> M ()
+emitStm stm = M (\i m -> ([], [stm], m, i, ()))
+
+captureStms :: M a -> M ([Stm], a)
+captureStms (M f) = M (\i m -> let (d, s, m2, i2, x) = f i m
+ in (d, [], m2, i2, (s, x)))
+
+genId :: M Int
+genId = M (\i m -> ([], [], m, i + 1, i))
+
+getTypedef :: Ty -> M (Maybe String)
+getTypedef t = M $ \i m -> ([], [], m, i, fst <$> Map.lookup t m)
+
+putTypedef :: Ty -> String -> String -> M ()
+putTypedef t name decl = M $ \i m -> ([], [], Map.insert t (name, decl) m, i, ())
+
+genName :: String -> M String
+genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId
+ where (sep, suf) = case reverse s of
+ [] -> ("x", "_")
+ c : _ | c `elem` "0123456789_" -> ("_", "")
+ | otherwise -> ("", "")
+
+-- Function values are returned as a function-pointer-typed expression
+compile :: Target -> Env env String -> Expr x env t -> M Inline
+compile tgt env = \case
+ EVar _ _ i -> pure $ IVar (prj env i)
+ ELet _ rhs e -> do
+ rhsi <- compile tgt env rhs
+ var <- genName "x"
+ rhsty <- writeType (typeOf rhs)
+ emitStm $ VarDef rhsty var (Just rhsi)
+ compile tgt (EPush var env) e
+
+ EBuild1 x k e -> compile tgt env $ EBuild x (SS SZ) (k :< VNil) e
+ EBuild x n k e -> case tgt of
+ Host -> do
+ fname <- genName "buildfun"
+ let n' = fromNat (unSNat n)
+ shapevars = ['s' : show i | i <- [0 .. n' - 1]]
+ emitFun $ FunDef Device fname (map ("int " ++) shapevars) _
+ emitStm $ Launch _ _ _
+ _
+ Device -> _
+
+writeType :: STy t -> M String
+writeType = \case
+ STArr _ t -> (++ "*") <$> writeType t
+ STNil -> pure "Nil"
+ STPair a b -> (\x y -> "std::pair<" ++ x ++ "," ++ y ++ ">") <$> writeType a <*> writeType b
+ STScal t -> case t of
+ STI32 -> pure "int32_t"
+ STI64 -> pure "int64_t"
+ STF32 -> pure "float"
+ STF64 -> pure "double"
+ STBool -> pure "bool"
diff --git a/src/PreludeCu.hs b/src/PreludeCu.hs
new file mode 100644
index 0000000..22909a9
--- /dev/null
+++ b/src/PreludeCu.hs
@@ -0,0 +1,9 @@
+{-# LANGUAGE TemplateHaskell #-}
+module PreludeCu where
+
+import Control.Monad.IO.Class (liftIO)
+import Language.Haskell.TH (Exp(LitE), Lit(StringL))
+
+
+prelude :: String
+prelude = $(LitE . StringL <$> liftIO (readFile "prelude.cu"))