diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2023-09-10 21:13:14 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2023-09-10 21:13:14 +0200 |
commit | 0bf9f5bb8a0873cad2e11faf83519b6e7ccf87d2 (patch) | |
tree | 41ddd52b0293319834b2130814414e76de434396 |
Initial
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | LICENSE | 24 | ||||
-rw-r--r-- | chad-fast.cabal | 28 | ||||
-rw-r--r-- | prelude.cu | 3 | ||||
-rw-r--r-- | src/AST.hs | 211 | ||||
-rw-r--r-- | src/CHAD.hs | 121 | ||||
-rw-r--r-- | src/Compile.hs | 120 | ||||
-rw-r--r-- | src/PreludeCu.hs | 9 |
8 files changed, 517 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c33954f --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +dist-newstyle/ @@ -0,0 +1,24 @@ +Copyright (c) 2023 Tom Smeding. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/chad-fast.cabal b/chad-fast.cabal new file mode 100644 index 0000000..35fa7f8 --- /dev/null +++ b/chad-fast.cabal @@ -0,0 +1,28 @@ +cabal-version: 2.2 +name: chad-fast +synopsis: Fast CHAD +version: 0.1.0.0 +license: BSD-3-Clause +license-file: LICENSE +author: Tom Smeding +maintainer: tom@tomsmeding.com +build-type: Simple + +library + exposed-modules: + AST + CHAD + -- Compile + PreludeCu + other-modules: + build-depends: + base >= 4.14 && < 4.19, + containers, + template-haskell, + some + hs-source-dirs: + src + default-language: + Haskell2010 + ghc-options: + -Wall -threaded diff --git a/prelude.cu b/prelude.cu new file mode 100644 index 0000000..e63ccf8 --- /dev/null +++ b/prelude.cu @@ -0,0 +1,3 @@ +#include <utility> + +struct Nil {}; diff --git a/src/AST.hs b/src/AST.hs new file mode 100644 index 0000000..4d642ba --- /dev/null +++ b/src/AST.hs @@ -0,0 +1,211 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveFoldable #-} +module AST where + +import Data.Kind (Type) +import Data.Int + + +data Nat = Z | S Nat + deriving (Show, Eq, Ord) + +data SNat n where + SZ :: SNat Z + SS :: SNat n -> SNat (S n) +deriving instance (Show (SNat n)) + +data Vec n t where + VNil :: Vec n t + (:<) :: t -> Vec n t -> Vec (S n) t +deriving instance Show t => Show (Vec n t) +deriving instance Functor (Vec n) +deriving instance Foldable (Vec n) + +data SList f l where + SNil :: SList f '[] + SCons :: f a -> SList f l -> SList f (a : l) +deriving instance (forall a. Show (f a)) => Show (SList f l) + +data Ty + = TNil + | TPair Ty Ty + | TArr Nat Ty -- ^ rank, element type + | TScal ScalTy + | TEVM [Ty] Ty + deriving (Show, Eq, Ord) + +data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool + deriving (Show, Eq, Ord) + +type STy :: Ty -> Type +data STy t where + STNil :: STy TNil + STPair :: STy a -> STy b -> STy (TPair a b) + STArr :: SNat n -> STy t -> STy (TArr n t) + STScal :: SScalTy t -> STy (TScal t) + STEVM :: SList STy env -> STy t -> STy (TEVM env t) +deriving instance Show (STy t) + +data SScalTy t where + STI32 :: SScalTy TI32 + STI64 :: SScalTy TI64 + STF32 :: SScalTy TF32 + STF64 :: SScalTy TF64 + STBool :: SScalTy TBool +deriving instance Show (SScalTy t) + +type TIx = TScal TI64 + +type Idx :: [Ty] -> Ty -> Type +data Idx env t where + IZ :: Idx (t : env) t + IS :: Idx env t -> Idx (a : env) t +deriving instance Show (Idx env t) + +type family ScalRep t where + ScalRep TI32 = Int32 + ScalRep TI64 = Int64 + ScalRep TF32 = Float + ScalRep TF64 = Double + ScalRep TBool = Bool + +type ConsN :: Nat -> a -> [a] -> [a] +type family ConsN n x l where + ConsN Z x l = l + ConsN (S n) x l = x : ConsN n x l + +type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type +data Expr x env t where + -- lambda calculus + EVar :: x t -> STy t -> Idx env t -> Expr x env t + ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t + + -- array operations + EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) + EBuild :: x (TArr n t) -> SNat n -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t) + EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + + -- expression operations + EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) + EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) + EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t + EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t + + -- EVM operations + EMOne :: Idx venv t -> Expr x env t -> Expr x env (TEVM venv TNil) +deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) + +type SOp :: Ty -> Ty -> Type +data SOp a t where + OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + OMul :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + ONeg :: SScalTy a -> SOp (TScal a) (TScal a) + OLt :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OLe :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + ONot :: SOp (TScal TBool) (TScal TBool) +deriving instance Show (SOp a t) + +opt2 :: SOp a t -> STy t +opt2 = \case + OAdd t -> STScal t + OMul t -> STScal t + ONeg t -> STScal t + OLt _ -> STScal STBool + OLe _ -> STScal STBool + OEq _ -> STScal STBool + ONot -> STScal STBool + +typeOf :: Expr x env t -> STy t +typeOf = \case + EVar _ t _ -> t + ELet _ _ e -> typeOf e + EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) + EBuild _ n _ e -> STArr n (typeOf e) + EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t + + -- expression operations + EConst _ t _ -> STScal t + EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t + EIdx _ e _ | STArr _ t <- typeOf e -> t + EOp _ op _ -> opt2 op + + EMOne _ _ -> STEVM _ STNil + +unSNat :: SNat n -> Nat +unSNat SZ = Z +unSNat (SS n) = S (unSNat n) + +unSTy :: STy t -> Ty +unSTy = \case + STNil -> TNil + STPair a b -> TPair (unSTy a) (unSTy b) + STArr n t -> TArr (unSNat n) (unSTy t) + STScal t -> TScal (unSScalTy t) + STEVM l t -> TEVM (unSList l) (unSTy t) + +unSList :: SList STy env -> [Ty] +unSList SNil = [] +unSList (SCons t l) = unSTy t : unSList l + +unSScalTy :: SScalTy t -> ScalTy +unSScalTy = \case + STI32 -> TI32 + STI64 -> TI64 + STF32 -> TF32 + STF64 -> TF64 + STBool -> TBool + +fromNat :: Nat -> Int +fromNat Z = 0 +fromNat (S n) = succ (fromNat n) + +data env :> env' where + WId :: env :> env + WSink :: env :> (t : env) + WCopy :: env :> env' -> (t : env) :> (t : env') + WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 +deriving instance Show (env :> env') + +(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 +(.>) = flip WThen + +infixr @> +(@>) :: env :> env' -> Idx env t -> Idx env' t +WId @> i = i +WSink @> i = IS i +WCopy _ @> IZ = IZ +WCopy w @> (IS i) = IS (w @> i) +WThen w1 w2 @> i = w2 @> w1 @> i + +weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t +weakenExpr w = \case + EVar x t i -> EVar x t (w @> i) + ELet x rhs body -> ELet x (weakenExpr w rhs) (weakenExpr (WCopy w) body) + EBuild1 x e1 e2 -> EBuild1 x (weakenExpr w e1) (weakenExpr (WCopy w) e2) + EBuild x n es e -> EBuild x n (weakenVec w es) (weakenExpr (wcopyN n w) e) + EFold1 x e1 e2 -> EFold1 x (weakenExpr (WCopy (WCopy w)) e1) (weakenExpr w e2) + EConst x t v -> EConst x t v + EIdx1 x e1 e2 -> EIdx1 x (weakenExpr w e1) (weakenExpr w e2) + EIdx x e1 es -> EIdx x (weakenExpr w e1) (weakenVec w es) + EOp x op e -> EOp x op (weakenExpr w e) + EMOne i e -> EMOne i (weakenExpr w e) + +wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env' +wcopyN SZ w = w +wcopyN (SS n) w = WCopy (wcopyN n w) + +weakenVec :: (env :> env') -> Vec n (Expr x env TIx) -> Vec n (Expr x env' TIx) +weakenVec _ VNil = VNil +weakenVec w (e :< v) = weakenExpr w e :< weakenVec w v diff --git a/src/CHAD.hs b/src/CHAD.hs new file mode 100644 index 0000000..17ee12b --- /dev/null +++ b/src/CHAD.hs @@ -0,0 +1,121 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD where + +import Data.Functor.Const + +import AST + + +type Ex = Expr (Const ()) + +data Bindings f env env' where + BTop :: Bindings f env env + BPush :: Bindings f env env' -> f env' t -> Bindings f env (t : env') +deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') + +weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) + -> env1 :> env2 -> Bindings f env1 env' + -> (forall env2'. Bindings f env2 env2' -> env' :> env2' -> r) -> r +weakenBindings _ w BTop k = k BTop w +weakenBindings wf w (BPush b x) k = + weakenBindings wf w b $ \b' w' -> k (BPush b' (wf w' x)) (WCopy w') + +sinkWithBindings :: Bindings f env env' -> env :> env' +sinkWithBindings BTop = WId +sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b + +bconcat :: Bindings f env1 env2 -> Bindings f env2 env3 -> Bindings f env1 env3 +bconcat b1 BTop = b1 +bconcat b1 (BPush b2 x) = BPush (bconcat b1 b2) x + +bconcat' :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) + -> Bindings f env env1 -> Bindings f env env2 -> (forall env12. Bindings env env12 -> r) -> r +bconcat' wf b1 b2 = weakenBindings +-- bconcat :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) +-- -> Bindings f env env1 -> Bindings f env env2 -> env :> env' +-- -> (forall env'12. Bindings f env' env'12 -> r) -> r +-- bconcat wf BTop b w k = weakenBindings wf w b $ \b' _ -> k b' +-- bconcat wf (BPush b x) b2 w k = +-- bconcat wf + +type family D1 t where + D1 TNil = TNil + D1 (TPair a b) = TPair (D1 a) (D1 b) + D1 (TArr n t) = TArr n (D1 t) + D1 (TScal t) = TScal t + +type family D2 t where + D2 TNil = TNil + D2 (TPair a b) = TPair (D2 a) (D2 b) + -- D2 (TArr n t) = _ + D2 (TScal t) = D2s t + +type family D2s t where + D2s TI32 = TNil + D2s TI64 = TNil + D2s TF32 = TScal TF32 + D2s TF64 = TScal TF64 + D2s TBool = TNil + +type family D1E env where + D1E '[] = '[] + D1E (t : env) = D1 t : D1E env + +type family D2E env where + D2E '[] = '[] + D2E (t : env) = D2 t : D2E env + +data Ret env t = + forall env'. + Ret (Bindings Ex (D1E env) env') + (Ex env' (D1 t)) + (Ex (D2 t : env') (TEVM (D2E env) TNil)) +deriving instance Show (Ret env t) + +d1 :: STy t -> STy (D1 t) +d1 STNil = STNil +d1 (STPair a b) = STPair (d1 a) (d1 b) +d1 (STArr n t) = STArr n (d1 t) +d1 (STScal t) = STScal t +d1 STEVM{} = error "EVM not allowed in input program" + +d2 :: STy t -> STy (D2 t) +d2 STNil = STNil +d2 (STPair a b) = STPair (d2 a) (d2 b) +d2 STArr{} = error "TODO arrays" +d2 (STScal t) = case t of + STI32 -> STNil + STI64 -> STNil + STF32 -> STScal STF32 + STF64 -> STScal STF64 + STBool -> STNil +d2 STEVM{} = error "EVM not allowed in input program" + +conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) +conv1Idx IZ = IZ +conv1Idx (IS i) = IS (conv1Idx i) + +conv2Idx :: Idx env t -> Idx (D2E env) (D2 t) +conv2Idx IZ = IZ +conv2Idx (IS i) = IS (conv2Idx i) + +drev :: Ex env t -> Ret env t +drev = \case + EVar _ t i -> + Ret BTop + (EVar ext (d1 t) (conv1Idx i)) + (EMOne (conv2Idx i) (EVar ext (d2 t) IZ)) + ELet _ rhs body -> + Ret _ + _ + _ + where + ext = Const () diff --git a/src/Compile.hs b/src/Compile.hs new file mode 100644 index 0000000..2fcff5d --- /dev/null +++ b/src/Compile.hs @@ -0,0 +1,120 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE DeriveFunctor #-} +module Compile where + +import Control.Monad (ap) +import Data.Kind (Type) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map + +import AST + + +data Body = Body [Stm] Inline -- body, return expr + deriving (Show) + +data Stm + = VarDef String String (Maybe Inline) -- type, name, initialiser + | Launch Inline Inline Body -- num blocks, block size, kernel function body + deriving (Show) + +-- inline cuda expression +data Inline + = IOp Inline String Inline + | IUOp String Inline + | ILit String + | IVar String + | ICall Inline [Inline] + deriving (Show) + +data Target = Host | Device + deriving (Show) + +data FunDef = FunDef Target String [String] Body -- name, params (full declarations), body + deriving (Show) + +type Env :: [Ty] -> Type -> Type +data Env env v where + ETop :: Env '[] v + EPush :: v -> Env env v -> Env (t : env) v + +prj :: Env env v -> Idx env t -> v +prj = \env idx -> go idx env + where go :: Idx env t -> Env env v -> v + go IZ (EPush v _) = v + go (IS i) (EPush _ env) = go i env + +-- generated global function definitions, generated local statements, function typedef cache (name, decl) +newtype M a = M (Int -> Map Ty (String, String) -> ([FunDef], [Stm], Map Ty (String, String), Int, a)) + deriving (Functor) +instance Applicative M where + pure x = M (\i m -> ([], [], m, i, x)) + (<*>) = ap +instance Monad M where + M f >>= g = M (\i m -> let (d1, s1, m1, i1, x) = f i m + (d2, s2, m2, i2, y) = let M h = g x in h i1 m1 + in (d1 <> d2, s1 <> s2, m2, i2, y)) + +emitFun :: FunDef -> M () +emitFun fd = M (\i m -> ([fd], [], m, i, ())) + +emitStm :: Stm -> M () +emitStm stm = M (\i m -> ([], [stm], m, i, ())) + +captureStms :: M a -> M ([Stm], a) +captureStms (M f) = M (\i m -> let (d, s, m2, i2, x) = f i m + in (d, [], m2, i2, (s, x))) + +genId :: M Int +genId = M (\i m -> ([], [], m, i + 1, i)) + +getTypedef :: Ty -> M (Maybe String) +getTypedef t = M $ \i m -> ([], [], m, i, fst <$> Map.lookup t m) + +putTypedef :: Ty -> String -> String -> M () +putTypedef t name decl = M $ \i m -> ([], [], Map.insert t (name, decl) m, i, ()) + +genName :: String -> M String +genName s = (\i -> s ++ sep ++ show i ++ suf) <$> genId + where (sep, suf) = case reverse s of + [] -> ("x", "_") + c : _ | c `elem` "0123456789_" -> ("_", "") + | otherwise -> ("", "") + +-- Function values are returned as a function-pointer-typed expression +compile :: Target -> Env env String -> Expr x env t -> M Inline +compile tgt env = \case + EVar _ _ i -> pure $ IVar (prj env i) + ELet _ rhs e -> do + rhsi <- compile tgt env rhs + var <- genName "x" + rhsty <- writeType (typeOf rhs) + emitStm $ VarDef rhsty var (Just rhsi) + compile tgt (EPush var env) e + + EBuild1 x k e -> compile tgt env $ EBuild x (SS SZ) (k :< VNil) e + EBuild x n k e -> case tgt of + Host -> do + fname <- genName "buildfun" + let n' = fromNat (unSNat n) + shapevars = ['s' : show i | i <- [0 .. n' - 1]] + emitFun $ FunDef Device fname (map ("int " ++) shapevars) _ + emitStm $ Launch _ _ _ + _ + Device -> _ + +writeType :: STy t -> M String +writeType = \case + STArr _ t -> (++ "*") <$> writeType t + STNil -> pure "Nil" + STPair a b -> (\x y -> "std::pair<" ++ x ++ "," ++ y ++ ">") <$> writeType a <*> writeType b + STScal t -> case t of + STI32 -> pure "int32_t" + STI64 -> pure "int64_t" + STF32 -> pure "float" + STF64 -> pure "double" + STBool -> pure "bool" diff --git a/src/PreludeCu.hs b/src/PreludeCu.hs new file mode 100644 index 0000000..22909a9 --- /dev/null +++ b/src/PreludeCu.hs @@ -0,0 +1,9 @@ +{-# LANGUAGE TemplateHaskell #-} +module PreludeCu where + +import Control.Monad.IO.Class (liftIO) +import Language.Haskell.TH (Exp(LitE), Lit(StringL)) + + +prelude :: String +prelude = $(LitE . StringL <$> liftIO (readFile "prelude.cu")) |