diff options
-rw-r--r-- | src/AST.hs | 94 | ||||
-rw-r--r-- | src/AST/Count.hs | 68 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 62 | ||||
-rw-r--r-- | src/CHAD.hs | 47 | ||||
-rw-r--r-- | src/Data.hs | 9 | ||||
-rw-r--r-- | src/Example.hs | 48 | ||||
-rw-r--r-- | src/Simplify.hs | 30 |
7 files changed, 210 insertions, 148 deletions
@@ -30,7 +30,7 @@ data Ty | TEither Ty Ty | TArr Nat Ty -- ^ rank, element type | TScal ScalTy - | TAccum Nat Ty -- ^ rank and element type of the array being accumulated to + | TAccum Ty deriving (Show, Eq, Ord) data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool @@ -43,7 +43,7 @@ data STy t where STEither :: STy a -> STy b -> STy (TEither a b) STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) - STAccum :: SNat n -> STy t -> STy (TAccum n t) + STAccum :: STy t -> STy (TAccum t) deriving instance Show (STy t) data SScalTy t where @@ -66,10 +66,23 @@ type family ScalRep t where ScalRep TF64 = Double ScalRep TBool = Bool -type ConsN :: Nat -> a -> [a] -> [a] -type family ConsN n x l where - ConsN Z x l = l - ConsN (S n) x l = x : ConsN n x l +-- | This index is flipped around from the usual direction: the smallest index +-- is at the _heart_ of the nesting, not at the outside. The outermost layer +-- indexes into the _outer_ dimension of the type @t@. This makes indices into +-- compound structures work properly with coproducts. +type family AcIdx t i where + AcIdx t Z = TNil + AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i) + AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i) + AcIdx (TArr Z t) (S i) = AcIdx t i + AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i) + +type family AcVal t i where + AcVal t Z = t + AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i) + AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i) + AcVal (TArr Z t) (S i) = AcVal t i + AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i -- General assumption: head of the list (whatever way it is associated) is the -- inner variable / inner array dimension. In pretty printing, the inner @@ -91,22 +104,23 @@ data Expr x env t where -- array operations EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) - EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t) + EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) - EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused + -- EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) - EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t + EIdx :: x t -> SNat n -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t -- accumulation effect - EWith :: Expr x env (TArr n t) -> Expr x (TAccum n t : env) a -> Expr x env (TPair a (TArr n t)) - EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil + EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) + EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil + -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil -- partiality EError :: STy a -> String -> Expr x env a @@ -117,10 +131,6 @@ type Ex = Expr (Const ()) ext :: Const () a ext = Const () -type family Replicate n x where - Replicate Z x = '[] - Replicate (S n) x = x : Replicate n x - type family Tup env where Tup '[] = TNil Tup (t : ts) = TPair (Tup ts) t @@ -129,6 +139,14 @@ tTup :: SList STy env -> STy (Tup env) tTup SNil = STNil tTup (SCons t ts) = STPair (tTup ts) t +eTup :: SList (Ex env) list -> Ex env (Tup list) +eTup SNil = ENil ext +eTup (e `SCons` es) = EPair ext (eTup es) e + +type family InvTup core env where + InvTup core '[] = core + InvTup core (t : ts) = InvTup (TPair core t) ts + type SOp :: Ty -> Ty -> Type data SOp a t where OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) @@ -169,17 +187,17 @@ typeOf = \case EBuild _ n _ e -> STArr n (typeOf e) EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t EUnit _ e -> STArr SZ (typeOf e) - EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t + -- EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t - EIdx _ e _ | STArr _ t <- typeOf e -> t - -- EShape _ e | STArr n _ <- typeOf e -> _ + EIdx _ _ e _ | STArr _ t <- typeOf e -> t + EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) EOp _ op _ -> opt2 op EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum1 _ _ _ -> STNil + EAccum _ _ _ _ -> STNil EError t _ -> t @@ -194,7 +212,7 @@ unSTy = \case STEither a b -> TEither (unSTy a) (unSTy b) STArr n t -> TArr (unSNat n) (unSTy t) STScal t -> TScal (unSScalTy t) - STAccum n t -> TAccum (unSNat n) (unSTy t) + STAccum t -> TAccum (unSTy t) unSList :: SList STy env -> [Ty] unSList SNil = [] @@ -231,17 +249,18 @@ subst' f w = \case EInr x t e -> EInr x t (subst' f w e) ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkFN n f) (wcopyN n w) b) + EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) EUnit x e -> EUnit x (subst' f w e) - EReplicate x e -> EReplicate x (subst' f w e) + -- EReplicate x e -> EReplicate x (subst' f w e) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) - EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es) + EIdx x n e es -> EIdx x n (subst' f w e) (subst' f w es) + EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum1 e1 e2 e3 -> EAccum1 (subst' f w e1) (subst' f w e2) (subst' f w e3) + EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3) EError t s -> EError t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -250,28 +269,9 @@ subst' f w = \case IZ -> EVar x' t (w' @> IZ) IS i -> f' x' t (WPop w') i - sinkFN :: SNat n - -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) - -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t - sinkFN SZ f' x t w' i = f' x t w' i - sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ) - sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i - weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) -wsinkN :: SNat n -> env :> ConsN n TIx env -wsinkN SZ = WId -wsinkN (SS n) = WSink .> wsinkN n - -wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env' -wcopyN SZ w = w -wcopyN (SS n) w = WCopy (wcopyN n w) - -wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env' -wpopN SZ w = w -wpopN (SS n) w = wpopN n (WPop w) - wUndoSubenv :: Subenv env env' -> env' :> env wUndoSubenv SETop = WId wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) @@ -299,11 +299,15 @@ instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair kn instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance (KnownNat n, KnownTy t) => KnownTy (TAccum n t) where knownTy = STAccum knownNat knownTy +instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy class KnownEnv env where knownEnv :: SList STy env instance KnownEnv '[] where knownEnv = SNil instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) -ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) (error "TODO" f) +ebuildUp1 n sh size f = + EBuild ext (SS n) (EPair ext sh size) $ + let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ + in EIdx ext n (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) + (EFst ext arg) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index a4ff9f2..39d26c2 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -36,6 +36,10 @@ data Occ = Occ { _occLexical :: Count deriving (Eq, Generic) deriving (Semigroup, Monoid) via Generically Occ +instance Show Occ where + showsPrec d (Occ l r) = showParen (d > 10) $ + showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r + -- | One of the two branches is taken (<||>) :: Occ -> Occ -> Occ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) @@ -47,9 +51,8 @@ scaleMany (Occ l _) = Occ l Many occCount :: Idx env a -> Expr x env t -> Occ occCount idx = getConst . occCountGeneral - (\i o -> if idx2int i == idx2int idx then Const o else mempty) + (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty) (\(Const o) -> Const o) - (\_ (Const o) -> Const o) (\(Const o1) (Const o2) -> Const (o1 <||> o2)) (\(Const o) -> Const (scaleMany o)) @@ -84,47 +87,48 @@ occEnvPop (OccPush o _) = o occEnvPop OccEnd = OccEnd occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral onehotOccEnv occEnvPop occEnvPopN (<||>!) scaleManyOccEnv - where - occEnvPopN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env - occEnvPopN _ OccEnd = OccEnd - occEnvPopN SZ e = e - occEnvPopN (SS n) (OccPush e _) = occEnvPopN n e +occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv occCountGeneral :: forall r env t x. (forall env'. Monoid (r env')) - => (forall env' a. Idx env' a -> Occ -> r env') -- ^ one-hot + => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot -> (forall env' a. r (a : env') -> r env') -- ^ unpush - -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env') -- ^ unpushN -> (forall env'. r env' -> r env' -> r env') -- ^ alternation -> (forall env'. r env' -> r env') -- ^ scale-many -> Expr x env t -> r env -occCountGeneral onehot unpush unpushN alter many = go +occCountGeneral onehot unpush alter many = go WId where - go :: Monoid (r env') => Expr x env' t' -> r env' - go = \case - EVar _ _ i -> onehot i (Occ One One) - ELet _ rhs body -> go rhs <> unpush (go body) - EPair _ a b -> go a <> go b - EFst _ e -> go e - ESnd _ e -> go e + go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env' + go w = \case + EVar _ _ i -> onehot w i (Occ One One) + ELet _ rhs body -> re rhs <> re1 body + EPair _ a b -> re a <> re b + EFst _ e -> re e + ESnd _ e -> re e ENil _ -> mempty - EInl _ _ e -> go e - EInr _ _ e -> go e - ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b)) - EBuild1 _ a b -> go a <> many (unpush (go b)) - EBuild _ n a b -> go a <> many (unpushN n (go b)) - EFold1 _ a b -> many (unpush (unpush (go a))) <> go b - EUnit _ e -> go e - EReplicate _ e -> go e + EInl _ _ e -> re e + EInr _ _ e -> re e + ECase _ e a b -> re e <> (re1 a `alter` re1 b) + EBuild1 _ a b -> re a <> many (re1 b) + EBuild _ _ a b -> re a <> many (re1 b) + EFold1 _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b + EUnit _ e -> re e + -- EReplicate _ e -> re e EConst{} -> mempty - EIdx0 _ e -> go e - EIdx1 _ a b -> go a <> go b - EIdx _ e es -> go e <> foldMap go es - EOp _ _ e -> go e - EWith a b -> go a <> unpush (go b) - EAccum1 a b e -> go a <> go b <> go e + EIdx0 _ e -> re e + EIdx1 _ a b -> re a <> re b + EIdx _ _ a b -> re a <> re b + EShape _ e -> re e + EOp _ _ e -> re e + EWith a b -> re a <> re1 b + EAccum _ a b e -> re a <> re b <> re e EError{} -> mempty + where + re :: Monoid (r env') => Expr x env' t'' -> r env' + re = go w + + re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env' + re1 = unpush . go (WSink .> w) deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index dbbc021..5610d36 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -1,16 +1,15 @@ -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE TupleSections #-} -module AST.Pretty where +{-# LANGUAGE TypeOperators #-} +module AST.Pretty (ppExpr) where import Control.Monad (ap) import Data.List (intersperse) -import Data.Foldable (toList) import Data.Functor.Const import AST @@ -29,10 +28,6 @@ valprj (VPush x _) IZ = x valprj (VPush _ env) (IS i) = valprj env i valprj VTop i = case i of {} -vpushN :: Vec n a -> Val (Const a) env -> Val (Const a) (ConsN n TIx env) -vpushN VNil v = v -vpushN (name :< names) v = VPush (Const name) (vpushN names v) - newtype M a = M { runM :: Int -> (a, Int) } deriving (Functor) instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap } @@ -115,12 +110,10 @@ ppExpr' d val = \case EBuild _ n a b -> do a' <- ppExpr' 11 val a - names <- sequence (vecGenerate n (\_ -> genName)) -- TODO generate underscores - e' <- ppExpr' 0 (vpushN names val) b + name <- genNameIfUsedIn (tTup (sreplicate n tIx)) IZ b + e' <- ppExpr' 0 (VPush (Const name) val) b return $ showParen (d > 10) $ - showString "build " . a' . showString " (\\[" - . foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names)))) - . showString ("] -> ") . e' . showString ")" + showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")" EFold1 _ a b -> do name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a @@ -135,9 +128,9 @@ ppExpr' d val = \case e' <- ppExpr' 11 val e return $ showParen (d > 10) $ showString "unit " . e' - EReplicate _ e -> do - e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString "replicate " . e' + -- EReplicate _ e -> do + -- e' <- ppExpr' 11 val e + -- return $ showParen (d > 10) $ showString "replicate " . e' EConst _ ty v -> return $ showString $ case ty of STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v @@ -151,14 +144,15 @@ ppExpr' d val = \case b' <- ppExpr' 9 val b return $ showParen (d > 8) $ a' . showString " ! " . b' - EIdx _ e es -> do - e' <- ppExpr' 9 val e - es' <- traverse (ppExpr' 0 val) es + EIdx _ _ a b -> do + a' <- ppExpr' 9 val a + b' <- ppExpr' 10 val b return $ showParen (d > 8) $ - e' . showString " ! " - . showString "[" - . foldr (.) id (intersperse (showString ", ") (reverse (toList es'))) - . showString "]" + a' . showString " !! " . b' + + EShape _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "shape " . e' EOp _ op (EPair _ a b) | (Infix, ops) <- operator op -> do @@ -175,30 +169,30 @@ ppExpr' d val = \case EWith e1 e2 -> do e1' <- ppExpr' 11 val e1 - let STArr n t = typeOf e1 - name <- genNameIfUsedIn' "ac" (STAccum n t) IZ e2 - e2' <- ppExpr' 11 (VPush (Const name) val) e2 + name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 + e2' <- ppExpr' 0 (VPush (Const name) val) e2 return $ showParen (d > 10) $ showString "with " . e1' . showString (" (\\" ++ name ++ " -> ") . e2' . showString ")" - EAccum1 e1 e2 e3 -> do + EAccum i e1 e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ showParen (d > 10) $ - showString "accum1 " . e1' . showString " " . e2' . showString " " . e3' + showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS ppExprLet d val etop = do - let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS) + let collect :: SVal env -> Expr x env t -> M ([(String, Occ, ShowS)], ShowS) collect val' (ELet _ rhs body) = do + let occ = occCount IZ body name <- genNameIfUsedIn (typeOf rhs) IZ body rhs' <- ppExpr' 0 val' rhs (binds, core) <- collect (VPush (Const name) val') body - return ((name, rhs') : binds, core) + return ((name, occ, rhs') : binds, core) collect val' e = ([],) <$> ppExpr' 0 val' e (binds, core) <- collect val etop @@ -210,7 +204,9 @@ ppExprLet d val etop = do showString ("let " ++ open) . foldr (.) id (intersperse (showString " ; ") - (map (\(name, rhs) -> showString (name ++ " = ") . rhs) binds)) + (map (\(name, _occ, rhs) -> + showString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") . rhs) + binds)) . showString (close ++ " in ") . core diff --git a/src/CHAD.hs b/src/CHAD.hs index 087a26e..692bb96 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -309,9 +309,6 @@ type family D2s t where D2s TF64 = TScal TF64 D2s TBool = TNil -type family D2Ac t where - D2Ac (TArr n t) = TAccum n (D2 t) - type family D1E env where D1E '[] = '[] D1E (t : env) = D1 t : D1E env @@ -322,7 +319,7 @@ type family D2E env where type family D2AcE env where D2AcE '[] = '[] - D2AcE (t : env) = D2Ac t : D2AcE env + D2AcE (t : env) = TAccum (D2 t) : D2AcE env -- | Select only the types from the environment that have the specified storage type family Select env sto s where @@ -351,16 +348,13 @@ d2 (STScal t) = case t of STBool -> STNil d2 STAccum{} = error "Accumulators not allowed in input program" -d2ac :: STy t -> STy (D2Ac t) -d2ac (STArr n t) = STAccum n (d2 t) -d2ac _ = error "Only arrays may appear in the accumulator environment" - conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) -conv2Idx :: Descr env sto -> Idx env t -> Either (Idx (D2E (Select env sto "accum")) (D2 t)) - (Idx (Select env sto "merge") t) +conv2Idx :: Descr env sto -> Idx env t + -> Either (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) + (Idx (Select env sto "merge") t) conv2Idx (DPush _ (_, SAccum)) IZ = Left IZ conv2Idx (DPush _ (_, SMerge)) IZ = Right IZ conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i) @@ -371,7 +365,7 @@ zero :: STy t -> Ex env (D2 t) zero STNil = ENil ext zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext) zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) -zero (STArr n t) = EBuild ext (vecGenerate n (\_ -> EConst ext STI64 0)) (zero t) +zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) zero (STScal t) = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -464,11 +458,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = envpro prosub (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr))) + autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) (#pro :++: #d :++: #shb :++: #acc :++: #tl) .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum storepl))) + .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) @@ -489,7 +483,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = -- goal: | ARE EQUAL || -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) WCopy (wf shbinds) - .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + .> WPick @(TAccum (D2 (TArr arrn arrt))) @(D2 dt : shbinds) (Const () `SCons` shbindsC) (WId @(D2AcE (Select env1 stoRepl "accum")))) -- "merge" values must be an array, so reject everything else. (TODO: generalise this) @@ -505,10 +499,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = -- STScal{} -> False -- STAccum{} -> error "An accumulator in merge storage?" -type family InvTup core env where - InvTup core '[] = core - InvTup core (t : ts) = InvTup (TPair core t) ts - makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) makeAccumulators SNil e = e makeAccumulators (STArr n t `SCons` envpro) e = @@ -753,7 +743,7 @@ d2e (SCons t ts) = SCons (d2 t) (d2e ts) d2ace :: SList STy env -> SList STy (D2AcE env) d2ace SNil = SNil -d2ace (SCons t ts) = SCons (d2ac t) (d2ace ts) +d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts) freezeRet :: Descr env sto -> Ret env sto t @@ -775,11 +765,11 @@ drev :: forall env sto t. drev des = \case EVar _ t i -> case conv2Idx des i of - Left _ -> + Left accI -> Ret BTop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) - (ENil ext) + (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI))) Right tupI -> Ret BTop @@ -1075,22 +1065,25 @@ drev des = \case -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil - -> - Ret binds - (EIdx1 ext e1 ei1) + , STArr (SS n) eltty <- typeOf e -> + Ret (binds `BPush` (tTup (sreplicate (SS n) tIx), EShape ext e1)) + (weakenExpr WSink (EIdx1 ext e1 ei1)) sub - (_ e2) + (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ + weakenExpr (WCopy (WSink .> WSink)) e2) -- These should be the next to be implemented, I think EFold1{} -> err_unsupported "EFold1" EShape{} -> err_unsupported "EShape" - EReplicate{} -> err_unsupported "EReplicate" + -- EReplicate{} -> err_unsupported "EReplicate" EIdx{} -> err_unsupported "EIdx" EBuild{} -> err_unsupported "EBuild" EWith{} -> err_accum - EAccum1{} -> err_accum + EAccum{} -> err_accum where err_accum = error "Accumulator operations unsupported in the source program" diff --git a/src/Data.hs b/src/Data.hs index 8c39c6c..eb6c033 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -5,6 +5,7 @@ {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module Data where @@ -25,6 +26,14 @@ sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2) sappend SNil l = l sappend (SCons x xs) l = SCons x (sappend xs l) +type family Replicate n x where + Replicate Z x = '[] + Replicate (S n) x = x : Replicate n x + +sreplicate :: SNat n -> f t -> SList f (Replicate n t) +sreplicate SZ _ = SNil +sreplicate (SS n) x = x `SCons` sreplicate n x + data Nat = Z | S Nat deriving (Show, Eq, Ord) diff --git a/src/Example.hs b/src/Example.hs index 86264e1..424351c 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} module Example where import AST @@ -114,6 +115,9 @@ senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"] descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge) +-- x:R n:I |- let a = unit x +-- b = build1 n (\i. let c = idx0 a in c * c) +-- in idx0 (b ! 3) ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) ex6 = ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $ @@ -122,3 +126,47 @@ ex6 = bin (OMul STF32) (EVar ext (STScal STF32) IZ) (EVar ext (STScal STF32) IZ)) $ (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3))) + +type R = TScal TF32 + +senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] +senv7 = + let tR = STScal STF32 + tpair = STPair tR tR + in tR `SCons` STPair (STPair (STPair STNil tpair) tpair) tpair `SCons` SNil + +descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"] +descr7 = + let tR = STScal STF32 + tpair = STPair tR tR + in DTop `DPush` (STPair (STPair (STPair STNil tpair) tpair) tpair, SMerge) `DPush` (tR, SMerge) + +-- A "neural network" except it's just scalars, not matrices. +-- ps:((((), (R,R)), (R,R)), (R,R)) x:R +-- |- let p1 = snd ps +-- p1' = fst ps +-- x1 = fst p1 * x + snd p1 +-- p2 = snd p1' +-- p2' = fst p1' +-- x2 = fst p2 * x + snd p2 +-- p3 = snd p2' +-- p3' = fst p2' +-- x3 = fst p3 * x + snd p3 +-- in x3 +ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R +ex7 = + let tR = STScal STF32 + tpair = STPair tR tR + + layer :: STy p -> Idx env p -> Idx env R -> Ex env R + layer parst@(STPair t (STPair (STScal STF32) (STScal STF32))) pars inp = + ELet ext (ESnd ext (EVar ext parst pars)) $ + ELet ext (EFst ext (EVar ext parst (IS pars))) $ + ELet ext (bin (OAdd STF32) (bin (OMul STF32) (EFst ext (EVar ext tpair (IS IZ))) + (EVar ext tR (IS (IS inp)))) + (ESnd ext (EVar ext tpair (IS IZ)))) $ + layer t (IS IZ) IZ + layer STNil _ inp = EVar ext tR inp + layer _ _ _ = error "Invalid layer inputs" + + in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) (IS IZ) IZ diff --git a/src/Simplify.hs b/src/Simplify.hs index 698c667..62a3a9c 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -8,8 +8,6 @@ {-# LANGUAGE TypeOperators #-} module Simplify where -import Data.Monoid - import AST import AST.Count import Data @@ -45,9 +43,10 @@ simplify' = \case -- let rotation ELet _ (ELet _ rhs a) b -> - ELet ext (simplify' rhs) $ - ELet ext (simplify' a) $ - weakenExpr (WCopy WSink) (simplify' b) + simplify' $ + ELet ext rhs $ + ELet ext a $ + weakenExpr (WCopy WSink) (simplify' b) -- beta rules for products EFst _ (EPair _ e _) -> simplify' e @@ -57,6 +56,13 @@ simplify' = \case ECase _ (EInl _ _ e) rhs _ -> simplify' (ELet ext e rhs) ECase _ (EInr _ _ e) _ rhs -> simplify' (ELet ext e rhs) + -- let floating to facilitate beta reduction + EFst _ (ELet _ rhs body) -> simplify' (ELet ext rhs (EFst ext body)) + ESnd _ (ELet _ rhs body) -> simplify' (ELet ext rhs (ESnd ext body)) + ECase _ (ELet _ rhs body) e1 e2 -> simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) + EIdx0 _ (ELet _ rhs body) -> simplify' (ELet ext rhs (EIdx0 ext body)) + EIdx1 _ (ELet _ rhs body) e -> simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) + -- TODO: array indexing (index of build, index of fold) -- TODO: constant folding for operations @@ -74,14 +80,15 @@ simplify' = \case EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b) EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b) EUnit _ e -> EUnit ext (simplify' e) - EReplicate _ e -> EReplicate ext (simplify' e) + -- EReplicate _ e -> EReplicate ext (simplify' e) EConst _ t v -> EConst ext t v EIdx0 _ e -> EIdx0 ext (simplify' e) EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b) - EIdx _ e es -> EIdx ext (simplify' e) (fmap simplify' es) + EIdx _ n a b -> EIdx ext n (simplify' a) (simplify' b) + EShape _ e -> EShape ext (simplify' e) EOp _ op e -> EOp ext op (simplify' e) EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2) - EAccum1 e1 e2 e3 -> EAccum1 (simplify' e1) (simplify' e2) (simplify' e3) + EAccum i e1 e2 e3 -> EAccum i (simplify' e1) (simplify' e2) (simplify' e3) EError t s -> EError t s cheapExpr :: Expr x env t -> Bool @@ -108,14 +115,15 @@ hasAdds = \case EBuild _ _ a b -> hasAdds a || hasAdds b EFold1 _ a b -> hasAdds a || hasAdds b EUnit _ e -> hasAdds e - EReplicate _ e -> hasAdds e + -- EReplicate _ e -> hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e EIdx1 _ a b -> hasAdds a || hasAdds b - EIdx _ e es -> hasAdds e || getAny (foldMap (Any . hasAdds) es) + EIdx _ _ a b -> hasAdds a || hasAdds b + EShape _ e -> hasAdds e EOp _ _ e -> hasAdds e EWith a b -> hasAdds a || hasAdds b - EAccum1 _ _ _ -> True + EAccum _ _ _ _ -> True EError _ _ -> False checkAccumInScope :: SList STy env -> Bool |