diff options
| -rw-r--r-- | chad-fast.cabal | 1 | ||||
| -rw-r--r-- | src/AST.hs | 32 | ||||
| -rw-r--r-- | src/AST/Count.hs | 28 | ||||
| -rw-r--r-- | src/AST/Env.hs | 43 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 16 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 6 | ||||
| -rw-r--r-- | src/CHAD.hs | 207 | ||||
| -rw-r--r-- | src/Example.hs | 15 | ||||
| -rw-r--r-- | src/Simplify.hs | 6 | 
9 files changed, 240 insertions, 114 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 19c2852..1bff84b 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -12,6 +12,7 @@ library    exposed-modules:      AST      AST.Count +    AST.Env      AST.Pretty      AST.Weaken      AST.Weaken.Auto @@ -19,6 +19,7 @@ import Data.Functor.Const  import Data.Kind (Type)  import Data.Int +import AST.Env  import AST.Weaken  import Data @@ -90,15 +91,17 @@ data Expr x env t where    -- array operations    EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) -  EBuild :: x (TArr n t) -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t) +  EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t)    EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)    EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) +  EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)  -- TODO: unused    -- expression operations    EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)    EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t    EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)    EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t +  EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))    EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t    -- accumulation effect @@ -114,6 +117,18 @@ type Ex = Expr (Const ())  ext :: Const () a  ext = Const () +type family Replicate n x where +  Replicate Z x = '[] +  Replicate (S n) x = x : Replicate n x + +type family Tup env where +  Tup '[] = TNil +  Tup (t : ts) = TPair (Tup ts) t + +tTup :: SList STy env -> STy (Tup env) +tTup SNil = STNil +tTup (SCons t ts) = STPair (tTup ts) t +  type SOp :: Ty -> Ty -> Type  data SOp a t where    OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) @@ -151,14 +166,16 @@ typeOf = \case    ECase _ _ a _ -> typeOf a    EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) -  EBuild _ es e -> STArr (vecLength es) (typeOf e) +  EBuild _ n _ e -> STArr n (typeOf e)    EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t    EUnit _ e -> STArr SZ (typeOf e) +  EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t    EConst _ t _ -> STScal t    EIdx0 _ e | STArr _ t <- typeOf e -> t    EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t    EIdx _ e _ | STArr _ t <- typeOf e -> t +  -- EShape _ e | STArr n _ <- typeOf e -> _    EOp _ op _ -> opt2 op    EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) @@ -214,9 +231,10 @@ subst' f w = \case    EInr x t e -> EInr x t (subst' f w e)    ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)    EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) -  EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e) +  EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkFN n f) (wcopyN n w) b)    EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)    EUnit x e -> EUnit x (subst' f w e) +  EReplicate x e -> EReplicate x (subst' f w e)    EConst x t v -> EConst x t v    EIdx0 x e -> EIdx0 x (subst' f w e)    EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) @@ -254,6 +272,11 @@ wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env'  wpopN SZ w = w  wpopN (SS n) w = wpopN n (WPop w) +wUndoSubenv :: Subenv env env' -> env' :> env +wUndoSubenv SETop = WId +wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub +  slistIdx :: SList f list -> Idx list t -> f t  slistIdx (SCons x _) IZ = x  slistIdx (SCons _ list) (IS i) = slistIdx list i @@ -281,3 +304,6 @@ instance (KnownNat n, KnownTy t) => KnownTy (TAccum n t) where knownTy = STAccum  class KnownEnv env where knownEnv :: SList STy env  instance KnownEnv '[] where knownEnv = SNil  instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv + +ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) +ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) (error "TODO" f) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 289c1fb..a4ff9f2 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -2,12 +2,14 @@  {-# LANGUAGE DeriveGeneric #-}  {-# LANGUAGE DerivingStrategies #-}  {-# LANGUAGE DerivingVia #-} +{-# LANGUAGE EmptyCase #-}  {-# LANGUAGE GADTs #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-}  {-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-}  {-# LANGUAGE TypeOperators #-}  module AST.Count where @@ -15,6 +17,7 @@ import Data.Functor.Const  import GHC.Generics (Generic, Generically(..))  import AST +import AST.Env  import Data @@ -110,9 +113,10 @@ occCountGeneral onehot unpush unpushN alter many = go        EInr _ _ e -> go e        ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b))        EBuild1 _ a b -> go a <> many (unpush (go b)) -      EBuild _ es e -> foldMap go es <> many (unpushN (vecLength es) (go e)) +      EBuild _ n a b -> go a <> many (unpushN n (go b))        EFold1 _ a b -> many (unpush (unpush (go a))) <> go b        EUnit _ e -> go e +      EReplicate _ e -> go e        EConst{} -> mempty        EIdx0 _ e -> go e        EIdx1 _ a b -> go a <> go b @@ -121,3 +125,25 @@ occCountGeneral onehot unpush unpushN alter many = go        EWith a b -> go a <> unpush (go b)        EAccum1 a b e -> go a <> go b <> go e        EError{} -> mempty + + +deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r +deleteUnused SNil OccEnd k = k SETop +deleteUnused (_ `SCons` env) OccEnd k = +  deleteUnused env OccEnd $ \sub -> k (SENo sub) +deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k = +  deleteUnused env occenv $ \sub -> +    case count of Zero -> k (SENo sub) +                  _    -> k (SEYes sub) + +unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t +unsafeWeakenWithSubenv = \sub -> +  subst (\x t i -> case sinkWithSubenv i sub of +                     Just i' -> EVar x t i' +                     Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") +  where +    sinkWithSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) +    sinkWithSubenv IZ (SEYes _) = Just IZ +    sinkWithSubenv IZ (SENo _) = Nothing +    sinkWithSubenv (IS i) (SEYes sub) = IS <$> sinkWithSubenv i sub +    sinkWithSubenv (IS i) (SENo sub) = sinkWithSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs new file mode 100644 index 0000000..c33bad3 --- /dev/null +++ b/src/AST/Env.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module AST.Env where + +import AST.Weaken +import Data + + +-- | @env'@ is a subset of @env@: each element of @env@ is either included in +-- @env'@ ('SEYes') or not included in @env'@ ('SENo'). +data Subenv env env' where +  SETop :: Subenv '[] '[] +  SEYes :: Subenv env env' -> Subenv (t : env) (t : env') +  SENo  :: Subenv env env' -> Subenv (t : env) env' +deriving instance Show (Subenv env env') + +subList :: SList f env -> Subenv env env' -> SList f env' +subList SNil SETop = SNil +subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) +subList (SCons _ xs) (SENo sub) = subList xs sub + +subenvAll :: SList f env -> Subenv env env +subenvAll SNil = SETop +subenvAll (SCons _ env) = SEYes (subenvAll env) + +subenvNone :: SList f env -> Subenv env '[] +subenvNone SNil = SETop +subenvNone (SCons _ env) = SENo (subenvNone env) + +subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] +subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) +subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) +subenvOnehot SNil i = case i of {} + +subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3 +subenvCompose SETop SETop = SETop +subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2) +subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) +subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 1dc9dd3..dbbc021 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -113,14 +113,12 @@ ppExpr' d val = \case      return $ showParen (d > 10) $        showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")" -  EBuild _ es e -> do -    es' <- mapM (ppExpr' 0 val) es -    names <- mapM (const genName) es  -- TODO generate underscores -    e' <- ppExpr' 0 (vpushN names val) e +  EBuild _ n a b -> do +    a' <- ppExpr' 11 val a +    names <- sequence (vecGenerate n (\_ -> genName))  -- TODO generate underscores +    e' <- ppExpr' 0 (vpushN names val) b      return $ showParen (d > 10) $ -      showString "build [" -      . foldr (.) id (intersperse (showString ", ") (reverse (toList es'))) -      . showString "] (\\[" +      showString "build " . a' . showString " (\\["        . foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names))))        . showString ("] -> ") . e' . showString ")" @@ -137,6 +135,10 @@ ppExpr' d val = \case      e' <- ppExpr' 11 val e      return $ showParen (d > 10) $ showString "unit " . e' +  EReplicate _ e -> do +    e' <- ppExpr' 11 val e +    return $ showParen (d > 10) $ showString "replicate " . e' +    EConst _ ty v -> return $ showString $ case ty of      STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index e0b5232..78276ca 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -39,7 +39,7 @@ splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i)  data env :> env' where    WId :: env :> env    WSink :: forall t env. env :> (t : env) -  WCopy :: env :> env' -> (t : env) :> (t : env') +  WCopy :: forall t env env'. env :> env' -> (t : env) :> (t : env')    WPop :: (t : env) :> env' -> env :> env'    WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3    WClosed :: SList (Const ()) env -> '[] :> env @@ -95,6 +95,10 @@ wSinks :: forall env bs f. SList f bs -> env :> Append bs env  wSinks SNil = WId  wSinks (SCons _ spine) = WSink .> wSinks spine +wSinksAnd :: forall env env' bs f. SList f bs -> env :> env' -> env :> Append bs env' +wSinksAnd SNil w = w +wSinksAnd (SCons _ spine) w = WSink .> wSinksAnd spine w +  wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2  wCopies SNil w = w  wCopies (SCons _ spine) w = WCopy (wCopies spine w) diff --git a/src/CHAD.hs b/src/CHAD.hs index 007ffe3..087a26e 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -30,10 +30,12 @@ module CHAD (  import Data.Bifunctor (first, second)  import Data.Functor.Const  import Data.Kind (Type) +import GHC.Stack (HasCallStack)  import GHC.TypeLits (Symbol)  import AST  import AST.Count +import AST.Env  import AST.Weaken.Auto  import Data  import Lemmas @@ -422,14 +424,6 @@ plusSparse t a b adder =            (EVar ext t (IS IZ))            (weakenExpr (WCopy (WCopy WSink)) adder))) -type family Tup env where -  Tup '[] = TNil -  Tup (t : ts) = TPair (Tup ts) t - -tTup :: SList STy env -> STy (Tup env) -tTup SNil = STNil -tTup (SCons t ts) = STPair (tTup ts) t -  zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))  zeroTup SNil = ENil ext  zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) @@ -437,18 +431,20 @@ zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)  accumPromote :: forall dt env sto proxy r.                  proxy dt               -> Descr env sto -             -> OccEnv env               -> (forall stoRepl envPro. -                    Descr env stoRepl +                    (Select env stoRepl "merge" ~ '[]) +                 => Descr env stoRepl                        -- ^ A revised environment description that switches                        -- arrays (used in the OccEnv) that are currently on -                      -- "merge" storage, to "accum" storage. -                 -> Subenv (Select env sto "merge") (Select env stoRepl "merge") -                      -- ^ The new storage has fewer "merge"-storage entries. +                      -- "merge" storage, to "accum" storage. Any other "merge" +                      -- entries are deleted.                   -> SList STy envPro                        -- ^ New entries on top of the original dual environment,                        -- that house the accumulators for the promoted arrays in                        -- the original environment. +                 -> Subenv (Select env sto "merge") envPro +                      -- ^ The promoted entries were merge entries in the +                      -- original environment.                   -> (forall shbinds.                              SList STy shbinds                           -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) @@ -458,16 +454,15 @@ accumPromote :: forall dt env sto proxy r.                        -- extended with some accumulators.                   -> r)               -> r -accumPromote _ DTop _ k = k DTop SETop SNil (\_ -> WId) -accumPromote _ descr OccEnd k = k descr (subenvAll (select SMerge descr)) SNil (\_ -> WId) -accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k = -  accumPromote pdty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf -> -    case (t, sto, occ) of +accumPromote _ DTop k = k DTop SNil SETop (\_ -> WId) +accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = +  accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub wf -> +    case sto of        -- Accumulators are left as-is -      (_, SAccum, _) -> +      SAccum ->          k (storepl `DPush` (t, SAccum)) -          mergesub            envpro +          prosub            (\shbinds ->              autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr)))                       (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) @@ -477,34 +472,29 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k =                          (#d :++: #shb :++: #acc :++: #tl)                          (#acc :++: (#d :++: #shb :++: #tl))) -      -- Arrays with "merge" storage and non-zero usage are promoted to an accumulator in envPro -      (STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero -> -        k (storepl `DPush` (t, SAccum)) -          (SENo mergesub) -          (STArr arrn arrt `SCons` envpro) -          (\(shbinds :: SList _ shbinds) -> -            let shbindsC = slistMap (\_ -> Const ()) shbinds -            in -            -- wf: -            --                 D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum"))  :>                Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) -            -- WCopy wf: -            --   TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) -            --                       WPICK: ^                                                                 THESE TWO  || -            -- goal:                        |                                                                 ARE EQUAL  || -            --   D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) -            WCopy (wf shbinds) -            .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) -                 (WId @(D2AcE (Select env1 stoRepl "accum")))) +      SMerge -> case t of +        -- Arrays with "merge" storage are promoted to an accumulator in envPro +        STArr (arrn :: SNat arrn) (arrt :: STy arrt) -> +          k (storepl `DPush` (t, SAccum)) +            (STArr arrn arrt `SCons` envpro) +            (SEYes prosub) +            (\(shbinds :: SList _ shbinds) -> +              let shbindsC = slistMap (\_ -> Const ()) shbinds +              in +              -- wf: +              --                 D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum"))  :>                Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) +              -- WCopy wf: +              --   TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) +              --                       WPICK: ^                                                                 THESE TWO  || +              -- goal:                        |                                                                 ARE EQUAL  || +              --   D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum"))  :>  TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) +              WCopy (wf shbinds) +              .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) +                   (WId @(D2AcE (Select env1 stoRepl "accum")))) -      -- Used "merge" values must be an array, so reject everything else. (TODO: generalise this) -      (_, SMerge, Occ _ c) -        | c > Zero -> -            error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t -        | otherwise -> -            k (storepl `DPush` (t, SMerge)) -              (SEYes mergesub) -              envpro -              wf +        -- "merge" values must be an array, so reject everything else. (TODO: generalise this) +        _ -> +          error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t    -- where    --   containsTArr :: STy t' -> Bool    --   containsTArr = \case @@ -537,14 +527,6 @@ uninvertTup (t `SCons` list) tcore e =              (ESnd ext (EVar ext recT IZ))              (ESnd ext (EFst ext (EVar ext recT IZ)))) --- | @env'@ is a subset of @env@: each element of @env@ is either included in --- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv env env' where -  SETop :: Subenv '[] '[] -  SEYes :: Subenv env env' -> Subenv (t : env) (t : env') -  SENo  :: Subenv env env' -> Subenv (t : env) env' -deriving instance Show (Subenv env env') -  data Ret env0 sto t =    forall shbinds env0Merge.      Ret (Bindings Ex (D1E env0) shbinds)  -- shared binds @@ -566,24 +548,6 @@ data Rets env0 sto env list =           (SList (RetPair env0 sto env shbinds) list)  deriving instance Show (Rets env0 sto env list) -subList :: SList f env -> Subenv env env' -> SList f env' -subList SNil SETop = SNil -subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) -subList (SCons _ xs) (SENo sub) = subList xs sub - -subenvAll :: SList f env -> Subenv env env -subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) - -subenvNone :: SList f env -> Subenv env '[] -subenvNone SNil = SETop -subenvNone (SCons _ env) = SENo (subenvNone env) - -subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] -subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) -subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) -subenvOnehot SNil i = case i of {} -  subenvPlus :: SList STy env             -> Subenv env env1 -> Subenv env env2             -> (forall env3. Subenv env env3 @@ -631,7 +595,7 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e =      in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)  expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) -assertSubenvEmpty :: Subenv env env' -> env' :~: '[] +assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]  assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl  assertSubenvEmpty SETop = Refl  assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" @@ -748,6 +712,10 @@ data Descr env sto where    DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)  deriving instance Show (Descr env sto) +descrList :: Descr env sto -> SList STy env +descrList DTop = SNil +descrList (des `DPush` (t, _)) = t `SCons` descrList des +  select :: Storage s -> Descr env sto -> SList STy (Select env sto s)  select _ DTop = SNil  select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) @@ -755,6 +723,26 @@ select s@SMerge (DPush des (_, SAccum)) = select s des  select s@SAccum (DPush des (_, SMerge)) = select s des  select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +-- | This could have more precise typing on the output storage. +subDescr :: Descr env sto -> Subenv env env' +         -> (forall sto'. Descr env' sto' +                       -> Subenv (Select env sto "merge") (Select env' sto' "merge") +                       -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) +                       -> Subenv (D1E env) (D1E env') +                       -> r) +         -> r +subDescr DTop SETop k = k DTop SETop SETop SETop +subDescr (des `DPush` (t, sto)) (SEYes sub) k = +  subDescr des sub $ \des' submerge subaccum subd1e -> +    case sto of +      SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) +      SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) +subDescr (des `DPush` (_, sto)) (SENo sub) k = +  subDescr des sub $ \des' submerge subaccum subd1e -> +    case sto of +      SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) +      SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) +  sD1eEnv :: Descr env sto -> SList STy (D1E env)  sD1eEnv DTop = SNil  sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) @@ -990,16 +978,18 @@ drev des = \case          (subenvNone (select SMerge des))          (ENil ext) -  EBuild1 _ ne e -    | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne -    , let eltty = typeOf e -> -    accumPromote eltty des (occEnvPop (occCountAll e)) $ \vdes proSub envPro wPro -> -    case drev (vdes `DPush` (tIx, SMerge)) e of { Ret e0 e1 sub e2 -> +  EBuild1 _ ne (orige :: Ex _ eltty) +    | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne  -- allowed to ignore ne2 here because ne has a discrete result +    , let eltty = typeOf orige -> +    deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> +    let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in +    subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> +    accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> +    case drev (prodes `DPush` (tIx, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 ->      case assertSubenvEmpty sub of { Refl -> -    case assertSubenvEmpty proSub of { Refl -> -    let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 in +    let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in      Ret (bconcat (ne0 `BPush` (tIx, ne1)) -                 (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0))) +                 (fst (weakenBindings weakenExpr (WCopy (wSinksAnd (bindingsBinds ne0) (wUndoSubenv subD1eUsed))) ve0)))          (EBuild1 ext             (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0)                                    &. #binds (tIx `SCons` bindingsBinds ne0) @@ -1007,7 +997,7 @@ drev des = \case                                   #binds                                   ((#ve0 :++: #binds) :++: #tl))                (EVar ext tIx IZ)) -           (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of +           (subst (\_ t i -> case splitIdx @(TIx : D1E env') (bindingsBinds e0) i of                                 Left ibind ->                                   let ibind' =                                         autoWeak (#ix (auto1 @TIx) @@ -1020,9 +1010,9 @@ drev des = \case                                   in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind')                                                           (EVar ext tIx IZ))                                 Right IZ -> EVar ext tIx IZ  -- build lambda index argument -                               Right (IS ienv) -> EVar ext t (IS (wSinks (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) @> ienv))) +                               Right (IS ienv) -> EVar ext t (IS (wSinksAnd (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) (wUndoSubenv subD1eUsed) @> ienv)))                    e1)) -        nsub +        (subenvCompose subMergeUsed proSub)          (ELet ext             (uninvertTup (d2e envPro) (STArr (SS SZ) STNil) $                makeAccumulators @_ @_ @(TArr (S Z) TNil) envPro $ @@ -1035,8 +1025,21 @@ drev des = \case                                          #binds                                          (#pro :++: #d :++: (#ve0 :++: #binds) :++: #tl))                       (EVar ext tIx IZ)) -                  -- TODO: use vectoriseExpr -                  (_ $ +                  (ELet ext (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (d2 eltty)) +                                                      (IS (wSinks @(TArr (S Z) (D2 eltty) : Append (Append (Vectorise (S Z) e_binds) (TIx : ne_binds)) (D2AcE (Select env sto "accum"))) +                                                                  (d2ace envPro) +                                                             @> IZ))) +                                                  (EVar ext tIx IZ))) $ +                   weakenExpr (autoWeak (#i (auto1 @TIx) +                                         &. #dpro (d2ace envPro) +                                         &. #d (d2 eltty `SCons` SNil) +                                         &. #darr (STArr (SS SZ) (d2 eltty) `SCons` SNil) +                                         &. #n (auto1 @TIx) +                                         &. #vbinds (bindingsBinds ve0) +                                         &. #ne0 (bindingsBinds ne0) +                                         &. #tl (d2ace (select SAccum des))) +                                        (#i :++: (#dpro :++: #d) :++: #vbinds :++: #tl) +                                        (#d :++: #i :++: #dpro :++: #darr :++: (#vbinds :++: #n :++: #ne0) :++: #tl)) $                     vectoriseExpr (sappend (d2ace envPro) (d2 eltty `SCons` SNil)) (bindingsBinds e0) (d2ace (select SAccum des)) $                     weakenExpr (autoWeak (#dpro (d2ace envPro)                                           &. #d (d2 eltty `SCons` SNil) @@ -1044,19 +1047,12 @@ drev des = \case                                           &. #tl (d2ace (select SAccum des)))                                          (#dpro :++: #d :++: #binds :++: #tl)                                          ((#dpro :++: #d) :++: #binds :++: #tl)) $ -                   weakenExpr (wPro (bindingsBinds e0)) e2)) $ +                   weakenExpr (wCopies (d2ace envPro) (WCopy @(D2 eltty) (wCopies (bindingsBinds e0) (wUndoSubenv subAccumUsed)))) $ +                   weakenExpr (wPro (bindingsBinds e0)) $ +                   e2)) $           ELet ext (ENil ext) $ -         weakenExpr (autoWeak (#nil (auto1 @TNil) -                               &. #d (auto1 @(D2 t)) -                               &. #nilarr (auto1 @(TArr (S Z) TNil)) -                               &. #ve0 (bindingsBinds ve0) -                               &. #n (auto1 @TIx) -                               &. #binds (bindingsBinds ne0) -                               &. #tl (d2ace (select SAccum des))) -                              (#nil :++: #binds :++: #tl) -                              (#nil :++: #nilarr :++: #d :++: (#ve0 :++: #n :++: #binds) :++: #tl)) -           ne2) -    }}} +         ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ))) +    }}    EUnit _ e      | Ret e0 e1 sub e2 <- drev des e -> @@ -1075,9 +1071,20 @@ drev des = \case          (ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $           weakenExpr (WCopy WSink) e2) +  EIdx1 _ e ei +    -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. +    | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) +        <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil +    -> +    Ret binds +        (EIdx1 ext e1 ei1) +        sub +        (_ e2) +    -- These should be the next to be implemented, I think -  EIdx1{} -> err_unsupported "EIdx1"    EFold1{} -> err_unsupported "EFold1" +  EShape{} -> err_unsupported "EShape" +  EReplicate{} -> err_unsupported "EReplicate"    EIdx{} -> err_unsupported "EIdx"    EBuild{} -> err_unsupported "EBuild" diff --git a/src/Example.hs b/src/Example.hs index 572d67e..86264e1 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -107,3 +107,18 @@ ex5 =      (bin (OMul STF32) (EVar ext (STScal STF32) IZ)                        (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ))                                          (EConst ext STF32 1.0))) + +senv6 :: SList STy [TScal TI64, TScal TF32] +senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil + +descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"] +descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge) + +ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) +ex6 = +  ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $ +  ELet ext (EBuild1 ext (EVar ext tIx (IS IZ)) $ +              ELet ext (EIdx0 ext (EVar ext (STArr SZ (STScal STF32)) (IS IZ))) $ +                bin (OMul STF32) (EVar ext (STScal STF32) IZ) +                                 (EVar ext (STScal STF32) IZ)) $ +    (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3))) diff --git a/src/Simplify.hs b/src/Simplify.hs index f2fc54a..698c667 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -71,9 +71,10 @@ simplify' = \case    EInr _ t e -> EInr ext t (simplify' e)    ECase _ e a b -> ECase ext (simplify' e) (simplify' a) (simplify' b)    EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b) -  EBuild _ es e -> EBuild ext (fmap simplify' es) (simplify' e) +  EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b)    EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b)    EUnit _ e -> EUnit ext (simplify' e) +  EReplicate _ e -> EReplicate ext (simplify' e)    EConst _ t v -> EConst ext t v    EIdx0 _ e -> EIdx0 ext (simplify' e)    EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b) @@ -104,9 +105,10 @@ hasAdds = \case    EInr _ _ e -> hasAdds e    ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b    EBuild1 _ a b -> hasAdds a || hasAdds b -  EBuild _ es e -> getAny (foldMap (Any . hasAdds) es) || hasAdds e +  EBuild _ _ a b -> hasAdds a || hasAdds b    EFold1 _ a b -> hasAdds a || hasAdds b    EUnit _ e -> hasAdds e +  EReplicate _ e -> hasAdds e    EConst _ _ _ -> False    EIdx0 _ e -> hasAdds e    EIdx1 _ a b -> hasAdds a || hasAdds b  | 
