diff options
Diffstat (limited to 'src/CHAD/Fusion/AST.hs')
| -rw-r--r-- | src/CHAD/Fusion/AST.hs | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/src/CHAD/Fusion/AST.hs b/src/CHAD/Fusion/AST.hs new file mode 100644 index 0000000..3cd188a --- /dev/null +++ b/src/CHAD/Fusion/AST.hs @@ -0,0 +1,83 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Fusion.AST where + +import Data.Dependent.Map (DMap) +import Data.Functor.Const +import Data.Kind (Type) +import Numeric.Natural + +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.Data + + +type FEx = Expr FGraph (Const ()) + +type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type +data FGraph x env t where + FGraph :: DMap NodeId (Node env) -> Tuple NodeId t -> FGraph (Const ()) env t + +data Node env t where + FFreeVar :: STy t -> Idx env t -> Node env t + FLoop :: SList NodeId args + -> SList STy outs + -> LoopNest args outs + -> Tuple (Idx outs) t + -> Node env t + +data NodeId t = NodeId Natural (STy t) + deriving (Show) + +data Tuple f t where + TupNil :: Tuple f TNil + TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b) + TupSingle :: f t -> Tuple f t +deriving instance (forall a. Show (f a)) => Show (Tuple f t) + +data LoopNest args outs where + Inner :: Bindings Ex args bs + -> SList (Idx (Append bs args)) outs + -> LoopNest args outs + -- this should be able to express a simple nesting of builds and sums. + Layer :: Bindings Ex args bs1 + -> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations) + -> LoopNest (TIx : Append bs1 args) loopouts + -> Partition BuildUp RedSum loopouts mapouts sumouts + -> Bindings Ex (Append sumouts (Append bs1 args)) bs2 + -> SList (Idx (Append bs2 (Append bs1 args))) outs + -> LoopNest args (Append outs mapouts) +deriving instance Show (LoopNest args outs) + +type Partition :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type +data Partition f1 f2 ts ts1 ts2 where + PNil :: Partition f1 f2 '[] '[] '[] + Part1 :: f1 t t1 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) (t1 : ts1) ts2 + Part2 :: f2 t t2 -> Partition f1 f2 ts ts1 ts2 -> Partition f1 f2 (t : ts) ts1 (t2 : ts2) +deriving instance (forall t t1. Show (f1 t t1), forall t t2. Show (f2 t t2)) => Show (Partition f1 f2 ts ts1 ts2) + +data BuildUp t t' where + BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t) +deriving instance Show (BuildUp t t') + +data RedSum t t' where + RedSum :: SMTy t -> RedSum t t +deriving instance Show (RedSum t t') + +-- type family Unzip t where +-- Unzip (TPair a b) = TPair (Unzip a) (Unzip b) +-- Unzip (TArr n t) = UnzipA n t + +-- type family UnzipA n t where +-- UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b) +-- UnzipA n t = TArr n t + +-- data Zipping ut t where +-- ZId :: Zipping t t +-- ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b) +-- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b)) +-- deriving instance Show (Zipping ut t) |
