blob: 3cd188a34a5b1eda7dbd10aae3a8a8a973328e71 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
|
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Fusion.AST where
import Data.Dependent.Map (DMap)
import Data.Functor.Const
import Data.Kind (Type)
import Numeric.Natural
import CHAD.AST
import CHAD.AST.Bindings
import CHAD.Data
-- | An 'Expr' instantiated with 'FGraph' and @'Const' ()@ as its first two
-- type arguments.
--
-- Deliberately written without trailing @env@ / @t@ parameters. A type
-- synonym must be fully applied to the parameters it declares, so this
-- form lets 'FEx' be used unsaturated.
type FEx = Expr FGraph (Const ())
-- | A fusion graph: a collection of 'Node's keyed by 'NodeId', together with
-- a 'Tuple' of node identifiers whose shape follows the result type @res@.
-- Nodes may refer to the free variables listed in @env@.
--
-- The single constructor pins the first index to @'Const' ()@.
type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type
data FGraph ann env res where
  FGraph
    :: DMap NodeId (Node env)  -- ^ every node of the graph, keyed by its identifier
    -> Tuple NodeId res        -- ^ the nodes that make up the graph's result
    -> FGraph (Const ()) env res
-- | A single node of an 'FGraph'. The node may refer to the free variables
-- in @env@ and produces a value of type @t@.
type Node :: [Ty] -> Ty -> Type
data Node env t where
  -- | A reference to a free variable of the enclosing environment.
  FFreeVar
    :: STy t      -- ^ type of the variable
    -> Idx env t  -- ^ position of the variable in @env@
    -> Node env t
  -- | A loop nest whose inputs are supplied by other nodes of the graph.
  FLoop
    :: SList NodeId args   -- ^ nodes supplying the loop nest's inputs
    -> SList STy outs      -- ^ types of the loop nest's outputs
    -> LoopNest args outs  -- ^ the loop nest itself
    -> Tuple (Idx outs) t  -- ^ how the outputs are arranged into this node's value
    -> Node env t
-- | Identifier of a node in an 'FGraph'. It pairs a numeric tag with the
-- singleton type of the value that the node produces.
--
-- NOTE(review): presumably the numeric tag is unique within a graph. Nothing
-- in this declaration enforces that. Using 'NodeId' as a 'DMap' key also
-- requires @GCompare NodeId@ (and @GEq@), which is not defined here.
type NodeId :: Ty -> Type
data NodeId t = NodeId Natural (STy t)
  deriving (Show)
-- | A tree of @f@-values shaped after the type @t@.
--
-- * 'TupNil' and 'TupPair' follow the 'TNil' / 'TPair' structure of @t@.
-- * 'TupSingle' stores one @f@-value for a whole (sub)type.
--
-- 'TupSingle' is allowed at every type, including 'TNil' and 'TPair'
-- types. The same @t@ can therefore be described by several different
-- 'Tuple' values.
type Tuple :: (Ty -> Type) -> Ty -> Type
data Tuple f t where
  TupNil :: Tuple f TNil
  TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b)
  TupSingle :: f t -> Tuple f t

deriving instance (forall a. Show (f a)) => Show (Tuple f t)
-- | A nest of loops that reads the environment @args@ and produces the list
-- of values @outs@.
type LoopNest :: [Ty] -> [Ty] -> Type
data LoopNest args outs where
  -- | The innermost body. It evaluates some bindings over @args@ and selects
  -- the outputs from the extended environment @Append bs args@.
  Inner
    :: Bindings Ex args bs
    -> SList (Idx (Append bs args)) outs
    -> LoopNest args outs

  -- NOTE(review): the post-loop bindings @bs2@ are typed in an environment
  -- that contains @sumouts@. However, the output indices range over
  -- @Append bs2 (Append bs1 args)@, which omits @sumouts@. A summed result
  -- can therefore only be returned by re-binding it in @bs2@. Confirm this
  -- is intended, rather than indexing into
  -- @Append bs2 (Append sumouts (Append bs1 args))@.

  -- | One loop layer around a nested 'LoopNest'. This should be able to
  -- express a simple nesting of builds and sums: each output of the nested
  -- body is either built up into an array with one more dimension
  -- ('BuildUp') or summed ('RedSum').
  --
  -- Arguments, in order:
  --
  -- 1. bindings evaluated before the loop, over @args@;
  -- 2. the loop width, as an index into those bindings;
  -- 3. the loop body, whose environment has a 'TIx' pushed in front
  --    (presumably the current iteration index);
  -- 4. for each body output, whether it is built up or summed;
  -- 5. bindings evaluated after the loop, which can see the sums;
  -- 6. the selected outputs @outs@.
  --
  -- The built-up arrays (@mapouts@) are appended after @outs@ in the result.
  Layer
    :: Bindings Ex args bs1
    -> Idx bs1 TIx -- ^ loop width (number of (parallel) iterations)
    -> LoopNest (TIx : Append bs1 args) loopouts
    -> Partition BuildUp RedSum loopouts mapouts sumouts
    -> Bindings Ex (Append sumouts (Append bs1 args)) bs2
    -> SList (Idx (Append bs2 (Append bs1 args))) outs
    -> LoopNest args (Append outs mapouts)

deriving instance Show (LoopNest args outs)
-- | A witness that the list @ts@ splits, element by element and in order,
-- into two lists.
--
-- * Elements related by @f@ go to the left list @ls@.
-- * Elements related by @g@ go to the right list @rs@.
type Partition
  :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
data Partition f g ts ls rs where
  -- | All three lists are empty.
  PNil :: Partition f g '[] '[] '[]
  -- | The input's head is related by @f@ to the head of the left list.
  Part1 :: f t l -> Partition f g ts ls rs -> Partition f g (t : ts) (l : ls) rs
  -- | The input's head is related by @g@ to the head of the right list.
  Part2 :: g t r -> Partition f g ts ls rs -> Partition f g (t : ts) ls (r : rs)

deriving instance (forall a b. Show (f a b), forall a b. Show (g a b)) => Show (Partition f g ts ls rs)
-- | Combining mode for a loop-body output that is built up into an array.
-- A body output of type @TArr n t@ becomes a result of type
-- @TArr (S n) t@, i.e. an array with one more dimension.
type BuildUp :: Ty -> Ty -> Type
data BuildUp t t' where
  BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t)

deriving instance Show (BuildUp t t')
-- | Combining mode for a loop-body output that is summed; the type is
-- unchanged. The 'SMTy' witness presumably supplies the monoid structure
-- used for the sum (confirm against its definition).
type RedSum :: Ty -> Ty -> Type
data RedSum t t' where
  RedSum :: SMTy t -> RedSum t t

deriving instance Show (RedSum t t')
-- type family Unzip t where
-- Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
-- Unzip (TArr n t) = UnzipA n t
-- type family UnzipA n t where
-- UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b)
-- UnzipA n t = TArr n t
-- data Zipping ut t where
-- ZId :: Zipping t t
-- ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b)
-- ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
-- deriving instance Show (Zipping ut t)
|