diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-09-05 12:12:57 +0200 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-09-05 12:12:57 +0200 | 
| commit | ff8aa61cfa28f9a8b2b599b7ca6ed9f404d7b377 (patch) | |
| tree | fd1a4a7cae714f3922c43dda03d53479477a1d83 /src | |
| parent | 5ffb110bb5382b31c1acd3910b2064b36eeb2f77 (diff) | |
Generic accumulators
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 94 | ||||
| -rw-r--r-- | src/AST/Count.hs | 68 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 62 | ||||
| -rw-r--r-- | src/CHAD.hs | 47 | ||||
| -rw-r--r-- | src/Data.hs | 9 | ||||
| -rw-r--r-- | src/Example.hs | 48 | ||||
| -rw-r--r-- | src/Simplify.hs | 30 | 
7 files changed, 210 insertions, 148 deletions
| @@ -30,7 +30,7 @@ data Ty    | TEither Ty Ty    | TArr Nat Ty  -- ^ rank, element type    | TScal ScalTy -  | TAccum Nat Ty  -- ^ rank and element type of the array being accumulated to +  | TAccum Ty    deriving (Show, Eq, Ord)  data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool @@ -43,7 +43,7 @@ data STy t where    STEither :: STy a -> STy b -> STy (TEither a b)    STArr :: SNat n -> STy t -> STy (TArr n t)    STScal :: SScalTy t -> STy (TScal t) -  STAccum :: SNat n -> STy t -> STy (TAccum n t) +  STAccum :: STy t -> STy (TAccum t)  deriving instance Show (STy t)  data SScalTy t where @@ -66,10 +66,23 @@ type family ScalRep t where    ScalRep TF64 = Double    ScalRep TBool = Bool -type ConsN :: Nat -> a -> [a] -> [a] -type family ConsN n x l where -  ConsN Z x l = l -  ConsN (S n) x l = x : ConsN n x l +-- | This index is flipped around from the usual direction: the smallest index +-- is at the _heart_ of the nesting, not at the outside. The outermost layer +-- indexes into the _outer_ dimension of the type @t@. This makes indices into +-- compound structures work properly with coproducts. +type family AcIdx t i where +  AcIdx t Z = TNil +  AcIdx (TPair a b) (S i) = TEither (AcIdx a i) (AcIdx b i) +  AcIdx (TEither a b) (S i) = TEither (AcIdx a i) (AcIdx b i) +  AcIdx (TArr Z t) (S i) = AcIdx t i +  AcIdx (TArr (S n) t) (S i) = TPair TIx (AcIdx (TArr n t) i) + +type family AcVal t i where +  AcVal t Z = t +  AcVal (TPair a b) (S i) = TEither (AcVal a i) (AcVal b i) +  AcVal (TEither a b) (S i) = TEither (AcVal a i) (AcVal b i) +  AcVal (TArr Z t) (S i) = AcVal t i +  AcVal (TArr (S n) t) (S i) = AcVal (TArr n t) i  -- General assumption: head of the list (whatever way it is associated) is the  -- inner variable / inner array dimension. In pretty printing, the inner @@ -91,22 +104,23 @@ data Expr x env t where    -- array operations    EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) -  EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t) +  EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)    EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)    EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) -  EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)  -- TODO: unused +  -- EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)  -- TODO: unused    -- expression operations    EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)    EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t    EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -  EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t +  EIdx :: x t -> SNat n -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t    EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))    EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t    -- accumulation effect -  EWith :: Expr x env (TArr n t) -> Expr x (TAccum n t : env) a -> Expr x env (TPair a (TArr n t)) -  EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil +  EWith :: Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) +  EAccum :: SNat i -> Expr x env (AcIdx t i) -> Expr x env (AcVal t i) -> Expr x env (TAccum t) -> Expr x env TNil +  -- EAccum1 :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum (S Z) t) -> Expr x env TNil    -- partiality    EError :: STy a -> String -> Expr x env a @@ -117,10 +131,6 @@ type Ex = Expr (Const ())  ext :: Const () a  ext = Const () -type family Replicate n x where -  Replicate Z x = '[] -  Replicate (S n) x = x : Replicate n x -  type family Tup env where    Tup '[] = TNil    Tup (t : ts) = TPair (Tup ts) t @@ -129,6 +139,14 @@ tTup :: SList STy env -> STy (Tup env)  tTup SNil = STNil  tTup (SCons t ts) = STPair (tTup ts) t +eTup :: SList (Ex env) list -> Ex env (Tup list) +eTup SNil = ENil ext +eTup (e `SCons` es) = EPair ext (eTup es) e + +type family InvTup core env where +  InvTup core '[] = core +  InvTup core (t : ts) = InvTup (TPair core t) ts +  type SOp :: Ty -> Ty -> Type  data SOp a t where    OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) @@ -169,17 +187,17 @@ typeOf = \case    EBuild _ n _ e -> STArr n (typeOf e)    EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t    EUnit _ e -> STArr SZ (typeOf e) -  EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t +  -- EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t    EConst _ t _ -> STScal t    EIdx0 _ e | STArr _ t <- typeOf e -> t    EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t -  EIdx _ e _ | STArr _ t <- typeOf e -> t -  -- EShape _ e | STArr n _ <- typeOf e -> _ +  EIdx _ _ e _ | STArr _ t <- typeOf e -> t +  EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)    EOp _ op _ -> opt2 op    EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) -  EAccum1 _ _ _ -> STNil +  EAccum _ _ _ _ -> STNil    EError t _ -> t @@ -194,7 +212,7 @@ unSTy = \case    STEither a b -> TEither (unSTy a) (unSTy b)    STArr n t -> TArr (unSNat n) (unSTy t)    STScal t -> TScal (unSScalTy t) -  STAccum n t -> TAccum (unSNat n) (unSTy t) +  STAccum t -> TAccum (unSTy t)  unSList :: SList STy env -> [Ty]  unSList SNil = [] @@ -231,17 +249,18 @@ subst' f w = \case    EInr x t e -> EInr x t (subst' f w e)    ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)    EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) -  EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkFN n f) (wcopyN n w) b) +  EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)    EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)    EUnit x e -> EUnit x (subst' f w e) -  EReplicate x e -> EReplicate x (subst' f w e) +  -- EReplicate x e -> EReplicate x (subst' f w e)    EConst x t v -> EConst x t v    EIdx0 x e -> EIdx0 x (subst' f w e)    EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) -  EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es) +  EIdx x n e es -> EIdx x n (subst' f w e) (subst' f w es) +  EShape x e -> EShape x (subst' f w e)    EOp x op e -> EOp x op (subst' f w e)    EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) -  EAccum1 e1 e2 e3 -> EAccum1 (subst' f w e1) (subst' f w e2) (subst' f w e3) +  EAccum i e1 e2 e3 -> EAccum i (subst' f w e1) (subst' f w e2) (subst' f w e3)    EError t s -> EError t s    where      sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -250,28 +269,9 @@ subst' f w = \case        IZ -> EVar x' t (w' @> IZ)        IS i -> f' x' t (WPop w') i -    sinkFN :: SNat n -           -> (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) -           -> x t -> STy t -> (ConsN n TIx env' :> env2) -> Idx (ConsN n TIx env) t -> Expr x env2 t -    sinkFN SZ f' x t w' i = f' x t w' i -    sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ) -    sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i -  weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t  weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) -wsinkN :: SNat n -> env :> ConsN n TIx env -wsinkN SZ = WId -wsinkN (SS n) = WSink .> wsinkN n - -wcopyN :: SNat n -> env :> env' -> ConsN n TIx env :> ConsN n TIx env' -wcopyN SZ w = w -wcopyN (SS n) w = WCopy (wcopyN n w) - -wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env' -wpopN SZ w = w -wpopN (SS n) w = wpopN n (WPop w) -  wUndoSubenv :: Subenv env env' -> env' :> env  wUndoSubenv SETop = WId  wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) @@ -299,11 +299,15 @@ instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair kn  instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy  instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy  instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance (KnownNat n, KnownTy t) => KnownTy (TAccum n t) where knownTy = STAccum knownNat knownTy +instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy  class KnownEnv env where knownEnv :: SList STy env  instance KnownEnv '[] where knownEnv = SNil  instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv  ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) -ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) (error "TODO" f) +ebuildUp1 n sh size f = +  EBuild ext (SS n) (EPair ext sh size) $ +    let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ +    in EIdx ext n (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) +                  (EFst ext arg) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index a4ff9f2..39d26c2 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -36,6 +36,10 @@ data Occ = Occ { _occLexical :: Count    deriving (Eq, Generic)    deriving (Semigroup, Monoid) via Generically Occ +instance Show Occ where +  showsPrec d (Occ l r) = showParen (d > 10) $ +    showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r +  -- | One of the two branches is taken  (<||>) :: Occ -> Occ -> Occ  Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) @@ -47,9 +51,8 @@ scaleMany (Occ l _) = Occ l Many  occCount :: Idx env a -> Expr x env t -> Occ  occCount idx =    getConst . occCountGeneral -    (\i o -> if idx2int i == idx2int idx then Const o else mempty) +    (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty)      (\(Const o) -> Const o) -    (\_ (Const o) -> Const o)      (\(Const o1) (Const o2) -> Const (o1 <||> o2))      (\(Const o) -> Const (scaleMany o)) @@ -84,47 +87,48 @@ occEnvPop (OccPush o _) = o  occEnvPop OccEnd = OccEnd  occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral onehotOccEnv occEnvPop occEnvPopN (<||>!) scaleManyOccEnv -  where -    occEnvPopN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env -    occEnvPopN _ OccEnd = OccEnd -    occEnvPopN SZ e = e -    occEnvPopN (SS n) (OccPush e _) = occEnvPopN n e +occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv  occCountGeneral :: forall r env t x.                     (forall env'. Monoid (r env')) -                => (forall env' a. Idx env' a -> Occ -> r env')  -- ^ one-hot +                => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env')  -- ^ one-hot                  -> (forall env' a. r (a : env') -> r env')  -- ^ unpush -                -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env')  -- ^ unpushN                  -> (forall env'. r env' -> r env' -> r env')  -- ^ alternation                  -> (forall env'. r env' -> r env')  -- ^ scale-many                  -> Expr x env t -> r env -occCountGeneral onehot unpush unpushN alter many = go +occCountGeneral onehot unpush alter many = go WId    where -    go :: Monoid (r env') => Expr x env' t' -> r env' -    go = \case -      EVar _ _ i -> onehot i (Occ One One) -      ELet _ rhs body -> go rhs <> unpush (go body) -      EPair _ a b -> go a <> go b -      EFst _ e -> go e -      ESnd _ e -> go e +    go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env' +    go w = \case +      EVar _ _ i -> onehot w i (Occ One One) +      ELet _ rhs body -> re rhs <> re1 body +      EPair _ a b -> re a <> re b +      EFst _ e -> re e +      ESnd _ e -> re e        ENil _ -> mempty -      EInl _ _ e -> go e -      EInr _ _ e -> go e -      ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b)) -      EBuild1 _ a b -> go a <> many (unpush (go b)) -      EBuild _ n a b -> go a <> many (unpushN n (go b)) -      EFold1 _ a b -> many (unpush (unpush (go a))) <> go b -      EUnit _ e -> go e -      EReplicate _ e -> go e +      EInl _ _ e -> re e +      EInr _ _ e -> re e +      ECase _ e a b -> re e <> (re1 a `alter` re1 b) +      EBuild1 _ a b -> re a <> many (re1 b) +      EBuild _ _ a b -> re a <> many (re1 b) +      EFold1 _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b +      EUnit _ e -> re e +      -- EReplicate _ e -> re e        EConst{} -> mempty -      EIdx0 _ e -> go e -      EIdx1 _ a b -> go a <> go b -      EIdx _ e es -> go e <> foldMap go es -      EOp _ _ e -> go e -      EWith a b -> go a <> unpush (go b) -      EAccum1 a b e -> go a <> go b <> go e +      EIdx0 _ e -> re e +      EIdx1 _ a b -> re a <> re b +      EIdx _ _ a b -> re a <> re b +      EShape _ e -> re e +      EOp _ _ e -> re e +      EWith a b -> re a <> re1 b +      EAccum _ a b e -> re a <> re b <> re e        EError{} -> mempty +      where +        re :: Monoid (r env') => Expr x env' t'' -> r env' +        re = go w + +        re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env' +        re1 = unpush . go (WSink .> w)  deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index dbbc021..5610d36 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -1,16 +1,15 @@ -{-# LANGUAGE LambdaCase #-}  {-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE EmptyCase #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-}  {-# LANGUAGE PolyKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE DeriveFunctor #-}  {-# LANGUAGE TupleSections #-} -module AST.Pretty where +{-# LANGUAGE TypeOperators #-} +module AST.Pretty (ppExpr) where  import Control.Monad (ap)  import Data.List (intersperse) -import Data.Foldable (toList)  import Data.Functor.Const  import AST @@ -29,10 +28,6 @@ valprj (VPush x _) IZ = x  valprj (VPush _ env) (IS i) = valprj env i  valprj VTop i = case i of {} -vpushN :: Vec n a -> Val (Const a) env -> Val (Const a) (ConsN n TIx env) -vpushN VNil v = v -vpushN (name :< names) v = VPush (Const name) (vpushN names v) -  newtype M a = M { runM :: Int -> (a, Int) }    deriving (Functor)  instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap } @@ -115,12 +110,10 @@ ppExpr' d val = \case    EBuild _ n a b -> do      a' <- ppExpr' 11 val a -    names <- sequence (vecGenerate n (\_ -> genName))  -- TODO generate underscores -    e' <- ppExpr' 0 (vpushN names val) b +    name <- genNameIfUsedIn (tTup (sreplicate n tIx)) IZ b +    e' <- ppExpr' 0 (VPush (Const name) val) b      return $ showParen (d > 10) $ -      showString "build " . a' . showString " (\\[" -      . foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names)))) -      . showString ("] -> ") . e' . showString ")" +      showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")"    EFold1 _ a b -> do      name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a @@ -135,9 +128,9 @@ ppExpr' d val = \case      e' <- ppExpr' 11 val e      return $ showParen (d > 10) $ showString "unit " . e' -  EReplicate _ e -> do -    e' <- ppExpr' 11 val e -    return $ showParen (d > 10) $ showString "replicate " . e' +  -- EReplicate _ e -> do +  --   e' <- ppExpr' 11 val e +  --   return $ showParen (d > 10) $ showString "replicate " . e'    EConst _ ty v -> return $ showString $ case ty of      STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v @@ -151,14 +144,15 @@ ppExpr' d val = \case      b' <- ppExpr' 9 val b      return $ showParen (d > 8) $ a' . showString " ! " . b' -  EIdx _ e es -> do -    e' <- ppExpr' 9 val e -    es' <- traverse (ppExpr' 0 val) es +  EIdx _ _ a b -> do +    a' <- ppExpr' 9 val a +    b' <- ppExpr' 10 val b      return $ showParen (d > 8) $ -      e' . showString " ! " -      . showString "[" -      . foldr (.) id (intersperse (showString ", ") (reverse (toList es'))) -      . showString "]" +      a' . showString " !! " . b' + +  EShape _ e -> do +    e' <- ppExpr' 11 val e +    return $ showParen (d > 10) $ showString "shape " . e'    EOp _ op (EPair _ a b)      | (Infix, ops) <- operator op -> do @@ -175,30 +169,30 @@ ppExpr' d val = \case    EWith e1 e2 -> do      e1' <- ppExpr' 11 val e1 -    let STArr n t = typeOf e1 -    name <- genNameIfUsedIn' "ac" (STAccum n t) IZ e2 -    e2' <- ppExpr' 11 (VPush (Const name) val) e2 +    name <- genNameIfUsedIn' "ac" (STAccum (typeOf e1)) IZ e2 +    e2' <- ppExpr' 0 (VPush (Const name) val) e2      return $ showParen (d > 10) $        showString "with " . e1' . showString (" (\\" ++ name ++ " -> ")        . e2' . showString ")" -  EAccum1 e1 e2 e3 -> do +  EAccum i e1 e2 e3 -> do      e1' <- ppExpr' 11 val e1      e2' <- ppExpr' 11 val e2      e3' <- ppExpr' 11 val e3      return $ showParen (d > 10) $ -      showString "accum1 " . e1' . showString " " . e2' . showString " " . e3' +      showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3'    EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)  ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS  ppExprLet d val etop = do -  let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS) +  let collect :: SVal env -> Expr x env t -> M ([(String, Occ, ShowS)], ShowS)        collect val' (ELet _ rhs body) = do +        let occ = occCount IZ body          name <- genNameIfUsedIn (typeOf rhs) IZ body          rhs' <- ppExpr' 0 val' rhs          (binds, core) <- collect (VPush (Const name) val') body -        return ((name, rhs') : binds, core) +        return ((name, occ, rhs') : binds, core)        collect val' e = ([],) <$> ppExpr' 0 val' e    (binds, core) <- collect val etop @@ -210,7 +204,9 @@ ppExprLet d val etop = do      showString ("let " ++ open)      . foldr (.) id          (intersperse (showString " ; ") -           (map (\(name, rhs) -> showString (name ++ " = ") . rhs) binds)) +           (map (\(name, _occ, rhs) -> +                    showString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") . rhs) +                binds))      . showString (close ++ " in ")      . core diff --git a/src/CHAD.hs b/src/CHAD.hs index 087a26e..692bb96 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -309,9 +309,6 @@ type family D2s t where    D2s TF64 = TScal TF64    D2s TBool = TNil -type family D2Ac t where -  D2Ac (TArr n t) = TAccum n (D2 t) -  type family D1E env where    D1E '[] = '[]    D1E (t : env) = D1 t : D1E env @@ -322,7 +319,7 @@ type family D2E env where  type family D2AcE env where    D2AcE '[] = '[] -  D2AcE (t : env) = D2Ac t : D2AcE env +  D2AcE (t : env) = TAccum (D2 t) : D2AcE env  -- | Select only the types from the environment that have the specified storage  type family Select env sto s where @@ -351,16 +348,13 @@ d2 (STScal t) = case t of    STBool -> STNil  d2 STAccum{} = error "Accumulators not allowed in input program" -d2ac :: STy t -> STy (D2Ac t) -d2ac (STArr n t) = STAccum n (d2 t) -d2ac _ = error "Only arrays may appear in the accumulator environment" -  conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)  conv1Idx IZ = IZ  conv1Idx (IS i) = IS (conv1Idx i) -conv2Idx :: Descr env sto -> Idx env t -> Either (Idx (D2E (Select env sto "accum")) (D2 t)) -                                                 (Idx (Select env sto "merge") t) +conv2Idx :: Descr env sto -> Idx env t +         -> Either (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) +                   (Idx (Select env sto "merge") t)  conv2Idx (DPush _   (_, SAccum)) IZ = Left IZ  conv2Idx (DPush _   (_, SMerge)) IZ = Right IZ  conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i) @@ -371,7 +365,7 @@ zero :: STy t -> Ex env (D2 t)  zero STNil = ENil ext  zero (STPair t1 t2) = EInl ext (STPair (d2 t1) (d2 t2)) (ENil ext)  zero (STEither t1 t2) = EInl ext (STEither (d2 t1) (d2 t2)) (ENil ext) -zero (STArr n t) = EBuild ext (vecGenerate n (\_ -> EConst ext STI64 0)) (zero t) +zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t)  zero (STScal t) = case t of    STI32 -> ENil ext    STI64 -> ENil ext @@ -464,11 +458,11 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =            envpro            prosub            (\shbinds -> -            autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr))) +            autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr)))                       (#acc :++: (#pro :++: #d :++: #shb :++: #tl))                       (#pro :++: #d :++: #shb :++: #acc :++: #tl)              .> WCopy (wf shbinds) -            .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum storepl))) +            .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl)))                          (#d :++: #shb :++: #acc :++: #tl)                          (#acc :++: (#d :++: #shb :++: #tl))) @@ -489,7 +483,7 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =                -- goal:                        |                                                                 ARE EQUAL  ||                --   D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))                WCopy (wf shbinds) -              .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) +              .> WPick @(TAccum (D2 (TArr arrn arrt))) @(D2 dt : shbinds) (Const () `SCons` shbindsC)                     (WId @(D2AcE (Select env1 stoRepl "accum"))))          -- "merge" values must be an array, so reject everything else. (TODO: generalise this) @@ -505,10 +499,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =    --     STScal{} -> False    --     STAccum{} -> error "An accumulator in merge storage?" -type family InvTup core env where -  InvTup core '[] = core -  InvTup core (t : ts) = InvTup (TPair core t) ts -  makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))  makeAccumulators SNil e = e  makeAccumulators (STArr n t `SCons` envpro) e = @@ -753,7 +743,7 @@ d2e (SCons t ts) = SCons (d2 t) (d2e ts)  d2ace :: SList STy env -> SList STy (D2AcE env)  d2ace SNil = SNil -d2ace (SCons t ts) = SCons (d2ac t) (d2ace ts) +d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts)  freezeRet :: Descr env sto            -> Ret env sto t @@ -775,11 +765,11 @@ drev :: forall env sto t.  drev des = \case    EVar _ t i ->      case conv2Idx des i of -      Left _ -> +      Left accI ->          Ret BTop              (EVar ext (d1 t) (conv1Idx i))              (subenvNone (select SMerge des)) -            (ENil ext) +            (EAccum SZ (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum (d2 t)) (IS accI)))        Right tupI ->          Ret BTop @@ -1075,22 +1065,25 @@ drev des = \case      -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.      | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)          <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil -    -> -    Ret binds -        (EIdx1 ext e1 ei1) +    , STArr (SS n) eltty <- typeOf e -> +    Ret (binds `BPush` (tTup (sreplicate (SS n) tIx), EShape ext e1)) +        (weakenExpr WSink (EIdx1 ext e1 ei1))          sub -        (_ e2) +        (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) +                               (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) +                               (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ +         weakenExpr (WCopy (WSink .> WSink)) e2)    -- These should be the next to be implemented, I think    EFold1{} -> err_unsupported "EFold1"    EShape{} -> err_unsupported "EShape" -  EReplicate{} -> err_unsupported "EReplicate" +  -- EReplicate{} -> err_unsupported "EReplicate"    EIdx{} -> err_unsupported "EIdx"    EBuild{} -> err_unsupported "EBuild"    EWith{} -> err_accum -  EAccum1{} -> err_accum +  EAccum{} -> err_accum    where      err_accum = error "Accumulator operations unsupported in the source program" diff --git a/src/Data.hs b/src/Data.hs index 8c39c6c..eb6c033 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -5,6 +5,7 @@  {-# LANGUAGE QuantifiedConstraints #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  module Data where @@ -25,6 +26,14 @@ sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2)  sappend SNil l = l  sappend (SCons x xs) l = SCons x (sappend xs l) +type family Replicate n x where +  Replicate Z x = '[] +  Replicate (S n) x = x : Replicate n x + +sreplicate :: SNat n -> f t -> SList f (Replicate n t) +sreplicate SZ _ = SNil +sreplicate (SS n) x = x `SCons` sreplicate n x +  data Nat = Z | S Nat    deriving (Show, Eq, Ord) diff --git a/src/Example.hs b/src/Example.hs index 86264e1..424351c 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -1,4 +1,5 @@  {-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-}  module Example where  import AST @@ -114,6 +115,9 @@ senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil  descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]  descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge) +-- x:R n:I |- let a = unit x +--                b = build1 n (\i. let c = idx0 a in c * c) +--            in idx0 (b ! 3)  ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)  ex6 =    ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $ @@ -122,3 +126,47 @@ ex6 =                  bin (OMul STF32) (EVar ext (STScal STF32) IZ)                                   (EVar ext (STScal STF32) IZ)) $      (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3))) + +type R = TScal TF32 + +senv7 :: SList STy [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] +senv7 = +  let tR = STScal STF32 +      tpair = STPair tR tR +  in tR `SCons` STPair (STPair (STPair STNil tpair) tpair) tpair `SCons` SNil + +descr7 :: Descr [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] ["merge", "merge"] +descr7 = +  let tR = STScal STF32 +      tpair = STPair tR tR +  in DTop `DPush` (STPair (STPair (STPair STNil tpair) tpair) tpair, SMerge) `DPush` (tR, SMerge) + +-- A "neural network" except it's just scalars, not matrices. +-- ps:((((), (R,R)), (R,R)), (R,R)) x:R +-- |- let p1 = snd ps +--        p1' = fst ps +--        x1 = fst p1 * x + snd p1 +--        p2 = snd p1' +--        p2' = fst p1' +--        x2 = fst p2 * x + snd p2 +--        p3 = snd p2' +--        p3' = fst p2' +--        x3 = fst p3 * x + snd p3 +--    in x3 +ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R +ex7 = +  let tR = STScal STF32 +      tpair = STPair tR tR + +      layer :: STy p -> Idx env p -> Idx env R -> Ex env R +      layer parst@(STPair t (STPair (STScal STF32) (STScal STF32))) pars inp = +        ELet ext (ESnd ext (EVar ext parst pars)) $ +        ELet ext (EFst ext (EVar ext parst (IS pars))) $ +        ELet ext (bin (OAdd STF32) (bin (OMul STF32) (EFst ext (EVar ext tpair (IS IZ))) +                                                     (EVar ext tR (IS (IS inp)))) +                                   (ESnd ext (EVar ext tpair (IS IZ)))) $ +          layer t (IS IZ) IZ +      layer STNil _ inp = EVar ext tR inp +      layer _ _ _ = error "Invalid layer inputs" + +  in layer (STPair (STPair (STPair STNil tpair) tpair) tpair) (IS IZ) IZ diff --git a/src/Simplify.hs b/src/Simplify.hs index 698c667..62a3a9c 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -8,8 +8,6 @@  {-# LANGUAGE TypeOperators #-}  module Simplify where -import Data.Monoid -  import AST  import AST.Count  import Data @@ -45,9 +43,10 @@ simplify' = \case    -- let rotation    ELet _ (ELet _ rhs a) b -> -    ELet ext (simplify' rhs) $ -    ELet ext (simplify' a) $ -      weakenExpr (WCopy WSink) (simplify' b) +    simplify' $ +      ELet ext rhs $ +      ELet ext a $ +        weakenExpr (WCopy WSink) (simplify' b)    -- beta rules for products    EFst _ (EPair _ e _) -> simplify' e @@ -57,6 +56,13 @@ simplify' = \case    ECase _ (EInl _ _ e) rhs _ -> simplify' (ELet ext e rhs)    ECase _ (EInr _ _ e) _ rhs -> simplify' (ELet ext e rhs) +  -- let floating to facilitate beta reduction +  EFst _ (ELet _ rhs body) -> simplify' (ELet ext rhs (EFst ext body)) +  ESnd _ (ELet _ rhs body) -> simplify' (ELet ext rhs (ESnd ext body)) +  ECase _ (ELet _ rhs body) e1 e2 -> simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) +  EIdx0 _ (ELet _ rhs body) -> simplify' (ELet ext rhs (EIdx0 ext body)) +  EIdx1 _ (ELet _ rhs body) e -> simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) +    -- TODO: array indexing (index of build, index of fold)    -- TODO: constant folding for operations @@ -74,14 +80,15 @@ simplify' = \case    EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b)    EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b)    EUnit _ e -> EUnit ext (simplify' e) -  EReplicate _ e -> EReplicate ext (simplify' e) +  -- EReplicate _ e -> EReplicate ext (simplify' e)    EConst _ t v -> EConst ext t v    EIdx0 _ e -> EIdx0 ext (simplify' e)    EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b) -  EIdx _ e es -> EIdx ext (simplify' e) (fmap simplify' es) +  EIdx _ n a b -> EIdx ext n (simplify' a) (simplify' b) +  EShape _ e -> EShape ext (simplify' e)    EOp _ op e -> EOp ext op (simplify' e)    EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2) -  EAccum1 e1 e2 e3 -> EAccum1 (simplify' e1) (simplify' e2) (simplify' e3) +  EAccum i e1 e2 e3 -> EAccum i (simplify' e1) (simplify' e2) (simplify' e3)    EError t s -> EError t s  cheapExpr :: Expr x env t -> Bool @@ -108,14 +115,15 @@ hasAdds = \case    EBuild _ _ a b -> hasAdds a || hasAdds b    EFold1 _ a b -> hasAdds a || hasAdds b    EUnit _ e -> hasAdds e -  EReplicate _ e -> hasAdds e +  -- EReplicate _ e -> hasAdds e    EConst _ _ _ -> False    EIdx0 _ e -> hasAdds e    EIdx1 _ a b -> hasAdds a || hasAdds b -  EIdx _ e es -> hasAdds e || getAny (foldMap (Any . hasAdds) es) +  EIdx _ _ a b -> hasAdds a || hasAdds b +  EShape _ e -> hasAdds e    EOp _ _ e -> hasAdds e    EWith a b -> hasAdds a || hasAdds b -  EAccum1 _ _ _ -> True +  EAccum _ _ _ _ -> True    EError _ _ -> False  checkAccumInScope :: SList STy env -> Bool | 
