aboutsummaryrefslogtreecommitdiff
path: root/src/CHAD/Fusion/AST.hs
blob: 3cd188a34a5b1eda7dbd10aae3a8a8a973328e71 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Fusion.AST where

import Data.Dependent.Map (DMap)
import Data.Functor.Const
import Data.Kind (Type)
import Numeric.Natural

import CHAD.AST
import CHAD.AST.Bindings
import CHAD.Data


-- | 'Expr' instantiated with 'FGraph' as its first type argument and
-- @'Const' ()@ as its second. @'Const' ()@ carries no information, and it is
-- also the only first argument that the 'FGraph' constructor produces, which
-- is presumably why the two are paired here.
--
-- NOTE(review): the roles of the type parameters of 'Expr' are defined in
-- "CHAD.AST" and are not visible here; presumably the first one is the hook
-- through which fusion graphs appear inside expressions. Confirm.
type FEx = Expr FGraph (Const ())

-- | A fusion graph in environment @env@ whose overall result has type @t@.
--
-- The graph is a table of nodes keyed by 'NodeId', where every node lives in
-- the same environment @env@. A 'Tuple' of node identifiers assembles the
-- result. The first type parameter can only ever be @'Const' ()@, because
-- that is what the sole constructor produces.
type FGraph :: (Ty -> Type) -> [Ty] -> Ty -> Type
data FGraph ann env t where
  FGraph
    :: DMap NodeId (Node env)  -- ^ nodes, keyed by their identifier
    -> Tuple NodeId res        -- ^ node identifiers making up the result
    -> FGraph (Const ()) env res

-- | A single node of a fusion graph. It lives in environment @env@ and
-- produces a value of type @t@.
type Node :: [Ty] -> Ty -> Type
data Node env t where
  -- | A reference to a variable of @env@, together with its type.
  FFreeVar :: STy t -> Idx env t -> Node env t
  -- | A loop nest. Its inputs are given by node identifiers, one per entry of
  -- @args@ (presumably supplying the loop nest's input environment). Its
  -- outputs have the types listed in @outs@, and a 'Tuple' of indices into
  -- @outs@ assembles this node's value.
  FLoop :: SList NodeId args   -- ^ one node identifier per entry of @args@
        -> SList STy outs      -- ^ types of the loop nest's outputs
        -> LoopNest args outs
        -> Tuple (Idx outs) t  -- ^ how the outputs form the node's value
        -> Node env t

-- | Identifier of a node in a fusion graph: a number together with the type
-- of the value the node produces.
--
-- NOTE(review): presumably the 'Natural' is unique within a graph; confirm
-- at the sites where identifiers are created.
type NodeId :: Ty -> Type
data NodeId t = NodeId Natural (STy t)
  deriving (Show)

-- | A tree of @f@-values indexed by the type @t@. 'TupNil' and 'TupPair'
-- mirror the 'TNil' and 'TPair' structure of @t@. Any subtree may instead be
-- given as a single @f@-value at that type with 'TupSingle'.
type Tuple :: (Ty -> Type) -> Ty -> Type
data Tuple f t where
  TupNil :: Tuple f TNil
  TupPair :: Tuple f a -> Tuple f b -> Tuple f (TPair a b)
  TupSingle :: f t -> Tuple f t
deriving instance (forall a. Show (f a)) => Show (Tuple f t)

-- | A nest of loops that reads the environment @args@ and produces outputs of
-- the types listed in @outs@.
type LoopNest :: [Ty] -> [Ty] -> Type
data LoopNest args outs where
  -- | Innermost level: a 'Bindings' of 'Ex' expressions over @args@ that
  -- introduces @bs@. It is followed by a selection of outputs from the new
  -- bindings @bs@ followed by @args@.
  Inner :: Bindings Ex args bs
        -> SList (Idx (Append bs args)) outs
        -> LoopNest args outs
  -- | One loop level. This should be able to express a simple nesting of
  -- builds and sums. The arguments, in order:
  --
  -- 1. bindings @bs1@ over @args@, visible to every later argument;
  -- 2. the loop width, an index into @bs1@;
  -- 3. the body, a nested 'LoopNest' whose environment is a 'TIx'
  --    (presumably the iteration index) followed by @bs1@ and @args@;
  -- 4. a 'Partition' of the body's outputs @loopouts@ into two groups:
  --    @mapouts@, each 'BuildUp' into an array with one more dimension, and
  --    @sumouts@, each 'RedSum' (presumably summed over the iterations);
  -- 5. bindings @bs2@ that can see @sumouts@, @bs1@ and @args@, but not
  --    @mapouts@ or the 'TIx';
  -- 6. a selection of outputs from @bs2@, @bs1@ and @args@.
  --
  -- The layer's outputs are the selected @outs@ followed by @mapouts@.
  --
  -- NOTE(review): 'Inner' selects from its new bindings followed by their
  -- whole input environment. The selection here instead omits @sumouts@, so
  -- a summed value can only be output by re-binding it in @bs2@. Confirm
  -- this is intended.
  Layer :: Bindings Ex args bs1
        -> Idx bs1 TIx  -- ^ loop width (number of (parallel) iterations)
        -> LoopNest (TIx : Append bs1 args) loopouts
        -> Partition BuildUp RedSum loopouts mapouts sumouts
        -> Bindings Ex (Append sumouts (Append bs1 args)) bs2
        -> SList (Idx (Append bs2 (Append bs1 args))) outs
        -> LoopNest args (Append outs mapouts)
deriving instance Show (LoopNest args outs)

-- | Evidence that the list @xs@ splits, element by element and in order, into
-- @ls@ and @rs@. Every element of @xs@ goes to exactly one side:
--
-- * 'Part1' sends it to the left list, paired with an @f@-value relating the
--   element to its image.
-- * 'Part2' sends it to the right list, paired with a @g@-value.
type Partition
  :: (Ty -> Ty -> Type) -> (Ty -> Ty -> Type) -> [Ty] -> [Ty] -> [Ty] -> Type
data Partition f g xs ls rs where
  PNil
    :: Partition f g '[] '[] '[]
  Part1
    :: f x l
    -> Partition f g xs ls rs
    -> Partition f g (x : xs) (l : ls) rs
  Part2
    :: g x r
    -> Partition f g xs ls rs
    -> Partition f g (x : xs) ls (r : rs)
deriving instance
  (forall x l. Show (f x l), forall x r. Show (g x r))
  => Show (Partition f g xs ls rs)

-- | Witness that an output of type @'TArr' n t@ becomes an output of type
-- @'TArr' ('S' n) t@, i.e. it gains one array dimension. Presumably this
-- collects the per-iteration values of a loop into an array. It is the first
-- side of the 'Partition' used by 'Layer'.
type BuildUp :: Ty -> Ty -> Type
data BuildUp t t' where
  BuildUp :: SNat n -> STy t -> BuildUp (TArr n t) (TArr (S n) t)
deriving instance Show (BuildUp t t')

-- | Witness that an output keeps its type @t@. It carries the 'SMTy' for @t@;
-- presumably the output is summed across loop iterations using that
-- description. It is the second side of the 'Partition' used by 'Layer'.
type RedSum :: Ty -> Ty -> Type
data RedSum t t' where
  RedSum :: SMTy t -> RedSum t t
deriving instance Show (RedSum t t')

-- type family Unzip t where
--   Unzip (TPair a b) = TPair (Unzip a) (Unzip b)
--   Unzip (TArr n t) = UnzipA n t

-- type family UnzipA n t where
--   UnzipA n (TPair a b) = TPair (UnzipA n a) (UnzipA n b)
--   UnzipA n t = TArr n t

-- data Zipping ut t where
--   ZId :: Zipping t t
--   ZPair :: Zipping ua a -> Zipping ub b -> Zipping (TPair ua ub) (TPair a b)
--   ZZip :: Zipping ua (TArr n a) -> Zipping ub (TArr n b) -> Zipping (TPair ua ub) (TArr n (TPair a b))
-- deriving instance Show (Zipping ut t)