summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST.hs94
-rw-r--r--src/AST/Count.hs68
-rw-r--r--src/AST/Pretty.hs62
-rw-r--r--src/CHAD.hs47
-rw-r--r--src/Data.hs9
-rw-r--r--src/Example.hs48
-rw-r--r--src/Simplify.hs30
7 files changed, 210 insertions, 148 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 802ee2a..f389467 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -30,7 +30,7 @@ data Ty
| TEither Ty Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
- | TAccum Nat Ty -- ^ rank and element type of the array being accumulated to
+ | TAccum Ty
deriving (Show, Eq, Ord)
data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
@@ -43,7 +43,7 @@ data STy t where
STEither :: STy a -> STy b -> STy (TEither a b)
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
- STAccum :: SNat n -> STy t -> STy (TAccum n t)
+ STAccum :: STy t -> STy (TAccum t)
deriving instance Show (STy t)
data SScalTy t where
@@ -66,10 +66,23 @@ type family ScalRep t where
ScalRep TF64 = Double
ScalRep TBool = Bool
-type ConsN :: Nat -> a -> [a] -> [a]
-type family ConsN n x l where
- ConsN Z x l = l
- ConsN (S n) x l = x : ConsN n x l
+-- | This index is flipped around from the usual direction: the smallest index
+-- is at the _heart_ of the nesting, not at the outside. The outermost layer
+-- indexes into the _outer_ dimension of the type @t@. This makes indices into
+-- compound structures work properly with coproducts.
+type family AcIdx t i where
+ AcIdx t Z = TNil
+ AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
+ AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i)
+ AcIdx (TArr Z t) (S i) = AcIdx t i
+ AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i)
+
+type family AcVal t i where
+ AcVal t Z = t
+ AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i)
+ AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i)
+ AcVal (TArr Z t) (S i) = AcVal t i
+ AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i
-- General assumption: head of the list (whatever way it is associated) is the
-- inner variable / inner array dimension. In pretty printing, the inner
@@ -91,22 +104,23 @@ data Expr x env t where
-- array operations
EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
- EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t)
+ EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
- EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused
+ -- EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
- EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t
+ EIdx :: x t -> SNat n -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
-- accumulation effect
- EWith :: Expr x env (TArr n t) -> Expr x (TAccum n t : env) a -> Expr x env (TPair a (TArr n t))
- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil
+ EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
+ EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil
+ -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil
-- partiality
EError :: STy a -> String -> Expr x env a
@@ -117,10 +131,6 @@ type Ex = Expr (Const ())
ext :: Const () a
ext = Const ()
-type family Replicate n x where
- Replicate Z x = '[]
- Replicate (S n) x = x : Replicate n x
-
type family Tup env where
Tup '[] = TNil
Tup (t : ts) = TPair (Tup ts) t
@@ -129,6 +139,14 @@ tTup :: SList STy env -> STy (Tup env)
tTup SNil = STNil
tTup (SCons t ts) = STPair (tTup ts) t
+eTup :: SList (Ex env) list -> Ex env (Tup list)
+eTup SNil = ENil ext
+eTup (e `SCons` es) = EPair ext (eTup es) e
+
+type family InvTup core env where
+ InvTup core '[] = core
+ InvTup core (t : ts) = InvTup (TPair core t) ts
+
type SOp :: Ty -> Ty -> Type
data SOp a t where
OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
@@ -169,17 +187,17 @@ typeOf = \case
EBuild _ n _ e -> STArr n (typeOf e)
EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
EUnit _ e -> STArr SZ (typeOf e)
- EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t
+ -- EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
- EIdx _ e _ | STArr _ t <- typeOf e -> t
- -- EShape _ e | STArr n _ <- typeOf e -> _
+ EIdx _ _ e _ | STArr _ t <- typeOf e -> t
+ EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
EOp _ op _ -> opt2 op
EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
- EAccum1 _ _ _ -> STNil
+ EAccum _ _ _ _ -> STNil
EError t _ -> t
@@ -194,7 +212,7 @@ unSTy = \case
STEither a b -> TEither (unSTy a) (unSTy b)
STArr n t -> TArr (unSNat n) (unSTy t)
STScal t -> TScal (unSScalTy t)
- STAccum n t -> TAccum (unSNat n) (unSTy t)
+ STAccum t -> TAccum (unSTy t)
unSList :: SList STy env -> [Ty]
unSList SNil = []
@@ -231,17 +249,18 @@ subst' f w = \case
EInr x t e -> EInr x t (subst' f w e)
ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
- EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkFN n f) (wcopyN n w) b)
+ EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
EUnit x e -> EUnit x (subst' f w e)
- EReplicate x e -> EReplicate x (subst' f w e)
+ -- EReplicate x e -> EReplicate x (subst' f w e)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
- EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
+ EIdx x n e es -> EIdx x n (subst' f w e) (subst' f w es)
+ EShape x e -> EShape x (subst' f w e)
EOp x op e -> EOp x op (subst' f w e)
EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
- EAccum1 e1 e2 e3 -> EAccum1 (subst' f w e1) (subst' f w e2) (subst' f w e3)
+ EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)
EError t s -> EError t s
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
@@ -250,28 +269,9 @@ subst' f w = \case
IZ -> EVar x' t (w' @> IZ)
IS i -> f' x' t (WPop w') i
- sinkFN :: SNat n
- -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
- -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t
- sinkFN SZ f' x t w' i = f' x t w' i
- sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ)
- sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i
-
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
-wsinkN :: SNat n -> env :> ConsN n TIx env
-wsinkN SZ = WId
-wsinkN (SS n) = WSink .> wsinkN n
-
-wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env'
-wcopyN SZ w = w
-wcopyN (SS n) w = WCopy (wcopyN n w)
-
-wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env'
-wpopN SZ w = w
-wpopN (SS n) w = wpopN n (WPop w)
-
wUndoSubenv :: Subenv env env' -> env' :> env
wUndoSubenv SETop = WId
wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub)
@@ -299,11 +299,15 @@ instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair kn
instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
-instance (KnownNat n, KnownTy t) => KnownTy (TAccum n t) where knownTy = STAccum knownNat knownTy
+instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
-ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) (error "TODO" f)
+ebuildUp1 n sh size f =
+ EBuild ext (SS n) (EPair ext sh size) $
+ let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
+ in EIdx ext n (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f))
+ (EFst ext arg)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index a4ff9f2..39d26c2 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -36,6 +36,10 @@ data Occ = Occ { _occLexical :: Count
deriving (Eq, Generic)
deriving (Semigroup, Monoid) via Generically Occ
+instance Show Occ where
+ showsPrec d (Occ l r) = showParen (d > 10) $
+ showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r
+
-- | One of the two branches is taken
(<||>) :: Occ -> Occ -> Occ
Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
@@ -47,9 +51,8 @@ scaleMany (Occ l _) = Occ l Many
occCount :: Idx env a -> Expr x env t -> Occ
occCount idx =
getConst . occCountGeneral
- (\i o -> if idx2int i == idx2int idx then Const o else mempty)
+ (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty)
(\(Const o) -> Const o)
- (\_ (Const o) -> Const o)
(\(Const o1) (Const o2) -> Const (o1 <||> o2))
(\(Const o) -> Const (scaleMany o))
@@ -84,47 +87,48 @@ occEnvPop (OccPush o _) = o
occEnvPop OccEnd = OccEnd
occCountAll :: Expr x env t -> OccEnv env
-occCountAll = occCountGeneral onehotOccEnv occEnvPop occEnvPopN (<||>!) scaleManyOccEnv
- where
- occEnvPopN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env
- occEnvPopN _ OccEnd = OccEnd
- occEnvPopN SZ e = e
- occEnvPopN (SS n) (OccPush e _) = occEnvPopN n e
+occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv
occCountGeneral :: forall r env t x.
(forall env'. Monoid (r env'))
- => (forall env' a. Idx env' a -> Occ -> r env') -- ^ one-hot
+ => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot
-> (forall env' a. r (a : env') -> r env') -- ^ unpush
- -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env') -- ^ unpushN
-> (forall env'. r env' -> r env' -> r env') -- ^ alternation
-> (forall env'. r env' -> r env') -- ^ scale-many
-> Expr x env t -> r env
-occCountGeneral onehot unpush unpushN alter many = go
+occCountGeneral onehot unpush alter many = go WId
where
- go :: Monoid (r env') => Expr x env' t' -> r env'
- go = \case
- EVar _ _ i -> onehot i (Occ One One)
- ELet _ rhs body -> go rhs <> unpush (go body)
- EPair _ a b -> go a <> go b
- EFst _ e -> go e
- ESnd _ e -> go e
+ go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env'
+ go w = \case
+ EVar _ _ i -> onehot w i (Occ One One)
+ ELet _ rhs body -> re rhs <> re1 body
+ EPair _ a b -> re a <> re b
+ EFst _ e -> re e
+ ESnd _ e -> re e
ENil _ -> mempty
- EInl _ _ e -> go e
- EInr _ _ e -> go e
- ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b))
- EBuild1 _ a b -> go a <> many (unpush (go b))
- EBuild _ n a b -> go a <> many (unpushN n (go b))
- EFold1 _ a b -> many (unpush (unpush (go a))) <> go b
- EUnit _ e -> go e
- EReplicate _ e -> go e
+ EInl _ _ e -> re e
+ EInr _ _ e -> re e
+ ECase _ e a b -> re e <> (re1 a `alter` re1 b)
+ EBuild1 _ a b -> re a <> many (re1 b)
+ EBuild _ _ a b -> re a <> many (re1 b)
+ EFold1 _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b
+ EUnit _ e -> re e
+ -- EReplicate _ e -> re e
EConst{} -> mempty
- EIdx0 _ e -> go e
- EIdx1 _ a b -> go a <> go b
- EIdx _ e es -> go e <> foldMap go es
- EOp _ _ e -> go e
- EWith a b -> go a <> unpush (go b)
- EAccum1 a b e -> go a <> go b <> go e
+ EIdx0 _ e -> re e
+ EIdx1 _ a b -> re a <> re b
+ EIdx _ _ a b -> re a <> re b
+ EShape _ e -> re e
+ EOp _ _ e -> re e
+ EWith a b -> re a <> re1 b
+ EAccum _ a b e -> re a <> re b <> re e
EError{} -> mempty
+ where
+ re :: Monoid (r env') => Expr x env' t'' -> r env'
+ re = go w
+
+ re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env'
+ re1 = unpush . go (WSink .> w)
deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index dbbc021..5610d36 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -1,16 +1,15 @@
-{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TupleSections #-}
-module AST.Pretty where
+{-# LANGUAGE TypeOperators #-}
+module AST.Pretty (ppExpr) where
import Control.Monad (ap)
import Data.List (intersperse)
-import Data.Foldable (toList)
import Data.Functor.Const
import AST
@@ -29,10 +28,6 @@ valprj (VPush x _) IZ = x
valprj (VPush _ env) (IS i) = valprj env i
valprj VTop i = case i of {}
-vpushN :: Vec n a -> Val (Const a) env -> Val (Const a) (ConsN n TIx env)
-vpushN VNil v = v
-vpushN (name :< names) v = VPush (Const name) (vpushN names v)
-
newtype M a = M { runM :: Int -> (a, Int) }
deriving (Functor)
instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap }
@@ -115,12 +110,10 @@ ppExpr' d val = \case
EBuild _ n a b -> do
a' <- ppExpr' 11 val a
- names <- sequence (vecGenerate n (\_ -> genName)) -- TODO generate underscores
- e' <- ppExpr' 0 (vpushN names val) b
+ name <- genNameIfUsedIn (tTup (sreplicate n tIx)) IZ b
+ e' <- ppExpr' 0 (VPush (Const name) val) b
return $ showParen (d > 10) $
- showString "build " . a' . showString " (\\["
- . foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names))))
- . showString ("] -> ") . e' . showString ")"
+ showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")"
EFold1 _ a b -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
@@ -135,9 +128,9 @@ ppExpr' d val = \case
e' <- ppExpr' 11 val e
return $ showParen (d > 10) $ showString "unit " . e'
- EReplicate _ e -> do
- e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString "replicate " . e'
+ -- EReplicate _ e -> do
+ -- e' <- ppExpr' 11 val e
+ -- return $ showParen (d > 10) $ showString "replicate " . e'
EConst _ ty v -> return $ showString $ case ty of
STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v
@@ -151,14 +144,15 @@ ppExpr' d val = \case
b' <- ppExpr' 9 val b
return $ showParen (d > 8) $ a' . showString " ! " . b'
- EIdx _ e es -> do
- e' <- ppExpr' 9 val e
- es' <- traverse (ppExpr' 0 val) es
+ EIdx _ _ a b -> do
+ a' <- ppExpr' 9 val a
+ b' <- ppExpr' 10 val b
return $ showParen (d > 8) $
- e' . showString " ! "
- . showString "["
- . foldr (.) id (intersperse (showString ", ") (reverse (toList es')))
- . showString "]"
+ a' . showString " !! " . b'
+
+ EShape _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "shape " . e'
EOp _ op (EPair _ a b)
| (Infix, ops) <- operator op -> do
@@ -175,30 +169,30 @@ ppExpr' d val = \case
EWith e1 e2 -> do
e1' <- ppExpr' 11 val e1
- let STArr n t = typeOf e1
- name <- genNameIfUsedIn' "ac" (STAccum n t) IZ e2
- e2' <- ppExpr' 11 (VPush (Const name) val) e2
+ name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2
+ e2' <- ppExpr' 0 (VPush (Const name) val) e2
return $ showParen (d > 10) $
showString "with " . e1' . showString (" (\\" ++ name ++ " -> ")
. e2' . showString ")"
- EAccum1 e1 e2 e3 -> do
+ EAccum i e1 e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ showParen (d > 10) $
- showString "accum1 " . e1' . showString " " . e2' . showString " " . e3'
+ showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'
EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
ppExprLet d val etop = do
- let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
+ let collect :: SVal env -> Expr x env t -> M ([(String, Occ, ShowS)], ShowS)
collect val' (ELet _ rhs body) = do
+ let occ = occCount IZ body
name <- genNameIfUsedIn (typeOf rhs) IZ body
rhs' <- ppExpr' 0 val' rhs
(binds, core) <- collect (VPush (Const name) val') body
- return ((name, rhs') : binds, core)
+ return ((name, occ, rhs') : binds, core)
collect val' e = ([],) <$> ppExpr' 0 val' e
(binds, core) <- collect val etop
@@ -210,7 +204,9 @@ ppExprLet d val etop = do
showString ("let " ++ open)
. foldr (.) id
(intersperse (showString " ; ")
- (map (\(name, rhs) -> showString (name ++ " = ") . rhs) binds))
+ (map (\(name, _occ, rhs) ->
+ showString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") . rhs)
+ binds))
. showString (close ++ " in ")
. core
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 087a26e..692bb96 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -309,9 +309,6 @@ type family D2s t where
D2s TF64 = TScal TF64
D2s TBool = TNil
-type family D2Ac t where
- D2Ac (TArr n t) = TAccum n (D2 t)
-
type family D1E env where
D1E '[] = '[]
D1E (t : env) = D1 t : D1E env
@@ -322,7 +319,7 @@ type family D2E env where
type family D2AcE env where
D2AcE '[] = '[]
- D2AcE (t : env) = D2Ac t : D2AcE env
+ D2AcE (t : env) = TAccum (D2 t) : D2AcE env
-- | Select only the types from the environment that have the specified storage
type family Select env sto s where
@@ -351,16 +348,13 @@ d2 (STScal t) = case t of
STBool -> STNil
d2 STAccum{} = error "Accumulators not allowed in input program"
-d2ac :: STy t -> STy (D2Ac t)
-d2ac (STArr n t) = STAccum n (d2 t)
-d2ac _ = error "Only arrays may appear in the accumulator environment"
-
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)
-conv2Idx :: Descr env sto -> Idx env t -> Either (Idx (D2E (Select env sto "accum")) (D2 t))
- (Idx (Select env sto "merge") t)
+conv2Idx :: Descr env sto -> Idx env t
+ -> Either (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t)))
+ (Idx (Select env sto "merge") t)
conv2Idx (DPush _ (_, SAccum)) IZ = Left IZ
conv2Idx (DPush _ (_, SMerge)) IZ = Right IZ
conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i)
@@ -371,7 +365,7 @@ zero :: STy t -> Ex env (D2 t)
zero STNil = ENil ext
zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext)
zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext)
-zero (STArr n t) = EBuild ext (vecGenerate n (\_ -> EConst ext STI64 0)) (zero t)
+zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t)
zero (STScal t) = case t of
STI32 -> ENil ext
STI64 -> ENil ext
@@ -464,11 +458,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
envpro
prosub
(\shbinds ->
- autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr)))
+ autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))
(#acc :++: (#pro :++: #d :++: #shb :++: #tl))
(#pro :++: #d :++: #shb :++: #acc :++: #tl)
.> WCopy (wf shbinds)
- .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum storepl)))
+ .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))
(#d :++: #shb :++: #acc :++: #tl)
(#acc :++: (#d :++: #shb :++: #tl)))
@@ -489,7 +483,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
-- goal: | ARE EQUAL ||
-- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
WCopy (wf shbinds)
- .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
+ .> WPick @(TAccum (D2 (TArr arrn arrt))) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
(WId @(D2AcE (Select env1 stoRepl "accum"))))
-- "merge" values must be an array, so reject everything else. (TODO: generalise this)
@@ -505,10 +499,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
-- STScal{} -> False
-- STAccum{} -> error "An accumulator in merge storage?"
-type family InvTup core env where
- InvTup core '[] = core
- InvTup core (t : ts) = InvTup (TPair core t) ts
-
makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
makeAccumulators SNil e = e
makeAccumulators (STArr n t `SCons` envpro) e =
@@ -753,7 +743,7 @@ d2e (SCons t ts) = SCons (d2 t) (d2e ts)
d2ace :: SList STy env -> SList STy (D2AcE env)
d2ace SNil = SNil
-d2ace (SCons t ts) = SCons (d2ac t) (d2ace ts)
+d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts)
freezeRet :: Descr env sto
-> Ret env sto t
@@ -775,11 +765,11 @@ drev :: forall env sto t.
drev des = \case
EVar _ t i ->
case conv2Idx des i of
- Left _ ->
+ Left accI ->
Ret BTop
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (select SMerge des))
- (ENil ext)
+ (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI)))
Right tupI ->
Ret BTop
@@ -1075,22 +1065,25 @@ drev des = \case
-- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
| Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
<- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
- ->
- Ret binds
- (EIdx1 ext e1 ei1)
+ , STArr (SS n) eltty <- typeOf e ->
+ Ret (binds `BPush` (tTup (sreplicate (SS n) tIx), EShape ext e1))
+ (weakenExpr WSink (EIdx1 ext e1 ei1))
sub
- (_ e2)
+ (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ)))
+ (EVar ext (STArr n (d2 eltty)) (IS IZ))) $
+ weakenExpr (WCopy (WSink .> WSink)) e2)
-- These should be the next to be implemented, I think
EFold1{} -> err_unsupported "EFold1"
EShape{} -> err_unsupported "EShape"
- EReplicate{} -> err_unsupported "EReplicate"
+ -- EReplicate{} -> err_unsupported "EReplicate"
EIdx{} -> err_unsupported "EIdx"
EBuild{} -> err_unsupported "EBuild"
EWith{} -> err_accum
- EAccum1{} -> err_accum
+ EAccum{} -> err_accum
where
err_accum = error "Accumulator operations unsupported in the source program"
diff --git a/src/Data.hs b/src/Data.hs
index 8c39c6c..eb6c033 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -5,6 +5,7 @@
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data where
@@ -25,6 +26,14 @@ sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2)
sappend SNil l = l
sappend (SCons x xs) l = SCons x (sappend xs l)
+type family Replicate n x where
+ Replicate Z x = '[]
+ Replicate (S n) x = x : Replicate n x
+
+sreplicate :: SNat n -> f t -> SList f (Replicate n t)
+sreplicate SZ _ = SNil
+sreplicate (SS n) x = x `SCons` sreplicate n x
+
data Nat = Z | S Nat
deriving (Show, Eq, Ord)
diff --git a/src/Example.hs b/src/Example.hs
index 86264e1..424351c 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
module Example where
import AST
@@ -114,6 +115,9 @@ senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil
descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge)
+-- x:R n:I |- let a = unit x
+-- b = build1 n (\i. let c = idx0 a in c * c)
+-- in idx0 (b ! 3)
ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
ex6 =
ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $
@@ -122,3 +126,47 @@ ex6 =
bin (OMul STF32) (EVar ext (STScal STF32) IZ)
(EVar ext (STScal STF32) IZ)) $
(EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3)))
+
+type R = TScal TF32
+
+senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)]
+senv7 =
+ let tR = STScal STF32
+ tpair = STPair tR tR
+ in tR `SCons` STPair (STPair (STPair STNil tpair) tpair) tpair `SCons` SNil
+
+descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"]
+descr7 =
+ let tR = STScal STF32
+ tpair = STPair tR tR
+ in DTop `DPush` (STPair (STPair (STPair STNil tpair) tpair) tpair, SMerge) `DPush` (tR, SMerge)
+
+-- A "neural network" except it's just scalars, not matrices.
+-- ps:((((), (R,R)), (R,R)), (R,R)) x:R
+-- |- let p1 = snd ps
+-- p1' = fst ps
+-- x1 = fst p1 * x + snd p1
+-- p2 = snd p1'
+-- p2' = fst p1'
+-- x2 = fst p2 * x + snd p2
+-- p3 = snd p2'
+-- p3' = fst p2'
+-- x3 = fst p3 * x + snd p3
+-- in x3
+ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R
+ex7 =
+ let tR = STScal STF32
+ tpair = STPair tR tR
+
+ layer :: STy p -> Idx env p -> Idx env R -> Ex env R
+ layer parst@(STPair t (STPair (STScal STF32) (STScal STF32))) pars inp =
+ ELet ext (ESnd ext (EVar ext parst pars)) $
+ ELet ext (EFst ext (EVar ext parst (IS pars))) $
+ ELet ext (bin (OAdd STF32) (bin (OMul STF32) (EFst ext (EVar ext tpair (IS IZ)))
+ (EVar ext tR (IS (IS inp))))
+ (ESnd ext (EVar ext tpair (IS IZ)))) $
+ layer t (IS IZ) IZ
+ layer STNil _ inp = EVar ext tR inp
+ layer _ _ _ = error "Invalid layer inputs"
+
+ in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) (IS IZ) IZ
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 698c667..62a3a9c 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -8,8 +8,6 @@
{-# LANGUAGE TypeOperators #-}
module Simplify where
-import Data.Monoid
-
import AST
import AST.Count
import Data
@@ -45,9 +43,10 @@ simplify' = \case
-- let rotation
ELet _ (ELet _ rhs a) b ->
- ELet ext (simplify' rhs) $
- ELet ext (simplify' a) $
- weakenExpr (WCopy WSink) (simplify' b)
+ simplify' $
+ ELet ext rhs $
+ ELet ext a $
+ weakenExpr (WCopy WSink) (simplify' b)
-- beta rules for products
EFst _ (EPair _ e _) -> simplify' e
@@ -57,6 +56,13 @@ simplify' = \case
ECase _ (EInl _ _ e) rhs _ -> simplify' (ELet ext e rhs)
ECase _ (EInr _ _ e) _ rhs -> simplify' (ELet ext e rhs)
+ -- let floating to facilitate beta reduction
+ EFst _ (ELet _ rhs body) -> simplify' (ELet ext rhs (EFst ext body))
+ ESnd _ (ELet _ rhs body) -> simplify' (ELet ext rhs (ESnd ext body))
+ ECase _ (ELet _ rhs body) e1 e2 -> simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))
+ EIdx0 _ (ELet _ rhs body) -> simplify' (ELet ext rhs (EIdx0 ext body))
+ EIdx1 _ (ELet _ rhs body) e -> simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
+
-- TODO: array indexing (index of build, index of fold)
-- TODO: constant folding for operations
@@ -74,14 +80,15 @@ simplify' = \case
EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b)
EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b)
EUnit _ e -> EUnit ext (simplify' e)
- EReplicate _ e -> EReplicate ext (simplify' e)
+ -- EReplicate _ e -> EReplicate ext (simplify' e)
EConst _ t v -> EConst ext t v
EIdx0 _ e -> EIdx0 ext (simplify' e)
EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b)
- EIdx _ e es -> EIdx ext (simplify' e) (fmap simplify' es)
+ EIdx _ n a b -> EIdx ext n (simplify' a) (simplify' b)
+ EShape _ e -> EShape ext (simplify' e)
EOp _ op e -> EOp ext op (simplify' e)
EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2)
- EAccum1 e1 e2 e3 -> EAccum1 (simplify' e1) (simplify' e2) (simplify' e3)
+ EAccum i e1 e2 e3 -> EAccum i (simplify' e1) (simplify' e2) (simplify' e3)
EError t s -> EError t s
cheapExpr :: Expr x env t -> Bool
@@ -108,14 +115,15 @@ hasAdds = \case
EBuild _ _ a b -> hasAdds a || hasAdds b
EFold1 _ a b -> hasAdds a || hasAdds b
EUnit _ e -> hasAdds e
- EReplicate _ e -> hasAdds e
+ -- EReplicate _ e -> hasAdds e
EConst _ _ _ -> False
EIdx0 _ e -> hasAdds e
EIdx1 _ a b -> hasAdds a || hasAdds b
- EIdx _ e es -> hasAdds e || getAny (foldMap (Any . hasAdds) es)
+ EIdx _ _ a b -> hasAdds a || hasAdds b
+ EShape _ e -> hasAdds e
EOp _ _ e -> hasAdds e
EWith a b -> hasAdds a || hasAdds b
- EAccum1 _ _ _ -> True
+ EAccum _ _ _ _ -> True
EError _ _ -> False
checkAccumInScope :: SList STy env -> Bool