diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 453 | ||||
| -rw-r--r-- | src/AST/Accum.hs | 60 | ||||
| -rw-r--r-- | src/AST/Count.hs | 164 | ||||
| -rw-r--r-- | src/AST/Env.hs | 59 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 137 | ||||
| -rw-r--r-- | src/CHAD.hs | 1131 | ||||
| -rw-r--r-- | src/CHAD/APIv1.hs | 178 | ||||
| -rw-r--r-- | src/CHAD/AST.hs | 709 | ||||
| -rw-r--r-- | src/CHAD/AST/Accum.hs | 137 | ||||
| -rw-r--r-- | src/CHAD/AST/Bindings.hs (renamed from src/AST/Bindings.hs) | 28 | ||||
| -rw-r--r-- | src/CHAD/AST/Count.hs | 927 | ||||
| -rw-r--r-- | src/CHAD/AST/Env.hs | 95 | ||||
| -rw-r--r-- | src/CHAD/AST/Pretty.hs (renamed from src/AST/Pretty.hs) | 189 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse.hs | 296 | ||||
| -rw-r--r-- | src/CHAD/AST/Sparse/Types.hs | 107 | ||||
| -rw-r--r-- | src/CHAD/AST/SplitLets.hs (renamed from src/AST/SplitLets.hs) | 83 | ||||
| -rw-r--r-- | src/CHAD/AST/Types.hs (renamed from src/AST/Types.hs) | 80 | ||||
| -rw-r--r-- | src/CHAD/AST/UnMonoid.hs | 252 | ||||
| -rw-r--r-- | src/CHAD/AST/Weaken.hs (renamed from src/AST/Weaken.hs) | 14 | ||||
| -rw-r--r-- | src/CHAD/AST/Weaken/Auto.hs (renamed from src/AST/Weaken/Auto.hs) | 57 | ||||
| -rw-r--r-- | src/CHAD/Accum.hs | 27 | ||||
| -rw-r--r-- | src/CHAD/Analysis/Identity.hs (renamed from src/Analysis/Identity.hs) | 135 | ||||
| -rw-r--r-- | src/CHAD/Array.hs (renamed from src/Array.hs) | 12 | ||||
| -rw-r--r-- | src/CHAD/Compile.hs (renamed from src/Compile.hs) | 854 | ||||
| -rw-r--r-- | src/CHAD/Compile/Exec.hs (renamed from src/Compile/Exec.hs) | 33 | ||||
| -rw-r--r-- | src/CHAD/Data.hs (renamed from src/Data.hs) | 10 | ||||
| -rw-r--r-- | src/CHAD/Data/VarMap.hs (renamed from src/Data/VarMap.hs) | 17 | ||||
| -rw-r--r-- | src/CHAD/Drev.hs | 1581 | ||||
| -rw-r--r-- | src/CHAD/Drev/Accum.hs | 72 | ||||
| -rw-r--r-- | src/CHAD/Drev/EnvDescr.hs (renamed from src/CHAD/EnvDescr.hs) | 34 | ||||
| -rw-r--r-- | src/CHAD/Drev/Top.hs (renamed from src/CHAD/Top.hs) | 80 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types.hs | 153 | ||||
| -rw-r--r-- | src/CHAD/Drev/Types/ToTan.hs (renamed from src/CHAD/Types/ToTan.hs) | 37 | ||||
| -rw-r--r-- | src/CHAD/Example.hs (renamed from src/Example.hs) | 57 | ||||
| -rw-r--r-- | src/CHAD/Example/GMM.hs (renamed from src/Example/GMM.hs) | 11 | ||||
| -rw-r--r-- | src/CHAD/Example/Types.hs (renamed from src/Example/Types.hs) | 6 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD.hs (renamed from src/ForwardAD.hs) | 57 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers.hs (renamed from src/ForwardAD/DualNumbers.hs) | 29 | ||||
| -rw-r--r-- | src/CHAD/ForwardAD/DualNumbers/Types.hs (renamed from src/ForwardAD/DualNumbers/Types.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Interpreter.hs | 468 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/Accum.hs (renamed from src/Interpreter/Accum.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/AccumOld.hs (renamed from src/Interpreter/AccumOld.hs) | 8 | ||||
| -rw-r--r-- | src/CHAD/Interpreter/Rep.hs (renamed from src/Interpreter/Rep.hs) | 79 | ||||
| -rw-r--r-- | src/CHAD/Language.hs | 423 | ||||
| -rw-r--r-- | src/CHAD/Language/AST.hs (renamed from src/Language/AST.hs) | 136 | ||||
| -rw-r--r-- | src/CHAD/Lemmas.hs (renamed from src/Lemmas.hs) | 2 | ||||
| -rw-r--r-- | src/CHAD/Simplify.hs | 620 | ||||
| -rw-r--r-- | src/CHAD/Simplify/TH.hs | 80 | ||||
| -rw-r--r-- | src/CHAD/Types.hs | 108 | ||||
| -rw-r--r-- | src/CHAD/Util/IdGen.hs (renamed from src/Util/IdGen.hs) | 2 | ||||
| -rw-r--r-- | src/Interpreter.hs | 448 | ||||
| -rw-r--r-- | src/Language.hs | 226 | ||||
| -rw-r--r-- | src/Simplify.hs | 348 |
53 files changed, 7537 insertions, 3788 deletions
diff --git a/src/AST.hs b/src/AST.hs deleted file mode 100644 index b8d23b4..0000000 --- a/src/AST.hs +++ /dev/null @@ -1,453 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where - -import Data.Functor.Const -import Data.Kind (Type) - -import Array -import AST.Accum -import AST.Types -import AST.Weaken -import CHAD.Types -import Data - - --- General assumption: head of the list (whatever way it is associated) is the --- inner variable / inner array dimension. In pretty printing, the inner --- variable / inner dimension is printed on the _right_. --- --- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the --- type transformation of CHAD. Indeed, these constructors are created _by_ --- CHAD, and are intended to be eliminated after simplification, so that the --- input program as well as the output program do not contain these --- constructors. --- TODO: ensure this by a "stage" type parameter. -type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type -data Expr x env t where - -- lambda calculus - EVar :: x t -> STy t -> Idx env t -> Expr x env t - ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t - - -- base types - EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b) - EFst :: x a -> Expr x env (TPair a b) -> Expr x env a - ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b - ENil :: x TNil -> Expr x env TNil - EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b) - EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b) - ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c - ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) - EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) - EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b - - -- array operations - EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) - EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) - ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) - EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) - EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - - -- expression operations - EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) - EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t - EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) - EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t - EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) - EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t - - -- custom derivatives - -- 'b' is the part of the input of the operation that derivatives should - -- be backpropagated to; 'a' is the inactive part. The dual field of - -- ECustom does not allow a derivative to be generated for 'a', and hence - -- none is propagated. - ECustom :: x t -> STy a -> STy b -> STy tape - -> Expr x [b, a] t -- ^ regular operation - -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass - -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative - -> Expr x env a -> Expr x env b - -> Expr x env t - - -- accumulation effect on monoids - EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t)) - EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil - - -- monoidal operations (to be desugared to regular operations after simplification) - EZero :: x (D2 t) -> STy t -> Expr x env (D2 t) - EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) - EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t) - - -- partiality - EError :: x a -> STy a -> String -> Expr x env a -deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) - -type Ex = Expr (Const ()) - -ext :: Const () a -ext = Const () - -data Commutative = Commut | Noncommut - deriving (Show) - -type SOp :: Ty -> Ty -> Type -data SOp a t where - OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - ONot :: SOp (TScal TBool) (TScal TBool) - OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) - OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) - OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right - ORound64 :: SOp (TScal TF64) (TScal TI64) - OToFl64 :: SOp (TScal TI64) (TScal TF64) - ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) -deriving instance Show (SOp a t) - -opt1 :: SOp a t -> STy a -opt1 = \case - OAdd t -> STPair (STScal t) (STScal t) - OMul t -> STPair (STScal t) (STScal t) - ONeg t -> STScal t - OLt t -> STPair (STScal t) (STScal t) - OLe t -> STPair (STScal t) (STScal t) - OEq t -> STPair (STScal t) (STScal t) - ONot -> STScal STBool - OAnd -> STPair (STScal STBool) (STScal STBool) - OOr -> STPair (STScal STBool) (STScal STBool) - OIf -> STScal STBool - ORound64 -> STScal STF64 - OToFl64 -> STScal STI64 - ORecip t -> STScal t - OExp t -> STScal t - OLog t -> STScal t - OIDiv t -> STPair (STScal t) (STScal t) - OMod t -> STPair (STScal t) (STScal t) - -opt2 :: SOp a t -> STy t -opt2 = \case - OAdd t -> STScal t - OMul t -> STScal t - ONeg t -> STScal t - OLt _ -> STScal STBool - OLe _ -> STScal STBool - OEq _ -> STScal STBool - ONot -> STScal STBool - OAnd -> STScal STBool - OOr -> STScal STBool - OIf -> STEither STNil STNil - ORound64 -> STScal STI64 - OToFl64 -> STScal STF64 - ORecip t -> STScal t - OExp t -> STScal t - OLog t -> STScal t - OIDiv t -> STScal t - OMod t -> STScal t - -typeOf :: Expr x env t -> STy t -typeOf = \case - EVar _ t _ -> t - ELet _ _ e -> typeOf e - - EPair _ a b -> STPair (typeOf a) (typeOf b) - EFst _ e | STPair t _ <- typeOf e -> t - ESnd _ e | STPair _ t <- typeOf e -> t - ENil _ -> STNil - EInl _ t2 e -> STEither (typeOf e) t2 - EInr _ t1 e -> STEither t1 (typeOf e) - ECase _ _ a _ -> typeOf a - ENothing _ t -> STMaybe t - EJust _ e -> STMaybe (typeOf e) - EMaybe _ e _ _ -> typeOf e - - EConstArr _ n t _ -> STArr n (STScal t) - EBuild _ n _ e -> STArr n (typeOf e) - EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t - ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - EUnit _ e -> STArr SZ (typeOf e) - EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t - EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - - EConst _ t _ -> STScal t - EIdx0 _ e | STArr _ t <- typeOf e -> t - EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t - EIdx _ e _ | STArr _ t <- typeOf e -> t - EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) - EOp _ op _ -> opt2 op - - ECustom _ _ _ _ e _ _ _ _ -> typeOf e - - EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ _ -> STNil - - EZero _ t -> d2 t - EPlus _ t _ _ -> d2 t - EOneHot _ t _ _ _ -> d2 t - - EError _ t _ -> t - -extOf :: Expr x env t -> x t -extOf = \case - EVar x _ _ -> x - ELet x _ _ -> x - EPair x _ _ -> x - EFst x _ -> x - ESnd x _ -> x - ENil x -> x - EInl x _ _ -> x - EInr x _ _ -> x - ECase x _ _ _ -> x - ENothing x _ -> x - EJust x _ -> x - EMaybe x _ _ _ -> x - EConstArr x _ _ _ -> x - EBuild x _ _ _ -> x - EFold1Inner x _ _ _ _ -> x - ESum1Inner x _ -> x - EUnit x _ -> x - EReplicate1Inner x _ _ -> x - EMaximum1Inner x _ -> x - EMinimum1Inner x _ -> x - EConst x _ _ -> x - EIdx0 x _ -> x - EIdx1 x _ _ -> x - EIdx x _ _ -> x - EShape x _ -> x - EOp x _ _ -> x - ECustom x _ _ _ _ _ _ _ _ -> x - EWith x _ _ _ -> x - EAccum x _ _ _ _ _ -> x - EZero x _ -> x - EPlus x _ _ _ -> x - EOneHot x _ _ _ _ -> x - EError x _ _ -> x - -mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t -mapExt f = \case - EVar x t i -> EVar (f x) t i - ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body) - EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b) - EFst x e -> EFst (f x) (mapExt f e) - ESnd x e -> ESnd (f x) (mapExt f e) - ENil x -> ENil (f x) - EInl x t e -> EInl (f x) t (mapExt f e) - EInr x t e -> EInr (f x) t (mapExt f e) - ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b) - ENothing x t -> ENothing (f x) t - EJust x e -> EJust (f x) (mapExt f e) - EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e) - EConstArr x n t a -> EConstArr (f x) n t a - EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b) - EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c) - ESum1Inner x e -> ESum1Inner (f x) (mapExt f e) - EUnit x e -> EUnit (f x) (mapExt f e) - EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b) - EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e) - EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e) - EConst x t v -> EConst (f x) t v - EIdx0 x e -> EIdx0 (f x) (mapExt f e) - EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b) - EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es) - EShape x e -> EShape (f x) (mapExt f e) - EOp x op e -> EOp (f x) op (mapExt f e) - ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2) - EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2) - EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3) - EZero x t -> EZero (f x) t - EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b) - EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b) - EError x t s -> EError (f x) t s - -substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t -substInline repl = - subst $ \x t -> \case IZ -> repl - IS i -> EVar x t i - -subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t -subst0 repl = - subst $ \_ t -> \case IZ -> repl - IS i -> EVar ext t (IS i) - -subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) - -> Expr x env t -> Expr x env' t -subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId - -subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) - -> env' :> envOut - -> Expr x env t - -> Expr x envOut t -subst' f w = \case - EVar x t i -> f x t w i - ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body) - EPair x a b -> EPair x (subst' f w a) (subst' f w b) - EFst x e -> EFst x (subst' f w e) - ESnd x e -> ESnd x (subst' f w e) - ENil x -> ENil x - EInl x t e -> EInl x t (subst' f w e) - EInr x t e -> EInr x t (subst' f w e) - ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) - ENothing x t -> ENothing x t - EJust x e -> EJust x (subst' f w e) - EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) - EConstArr x n t a -> EConstArr x n t a - EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) - ESum1Inner x e -> ESum1Inner x (subst' f w e) - EUnit x e -> EUnit x (subst' f w e) - EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) - EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) - EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) - EConst x t v -> EConst x t v - EIdx0 x e -> EIdx0 x (subst' f w e) - EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) - EIdx x e es -> EIdx x (subst' f w e) (subst' f w es) - EShape x e -> EShape x (subst' f w e) - EOp x op e -> EOp x op (subst' f w e) - ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) - EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) - EZero x t -> EZero x t - EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) - EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) - EError x t s -> EError x t s - where - sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) - -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t - sinkF f' x' t w' = \case - IZ -> EVar x' t (w' @> IZ) - IS i -> f' x' t (WPop w') i - -weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t -weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) - -class KnownScalTy t where knownScalTy :: SScalTy t -instance KnownScalTy TI32 where knownScalTy = STI32 -instance KnownScalTy TI64 where knownScalTy = STI64 -instance KnownScalTy TF32 where knownScalTy = STF32 -instance KnownScalTy TF64 where knownScalTy = STF64 -instance KnownScalTy TBool where knownScalTy = STBool - -class KnownTy t where knownTy :: STy t -instance KnownTy TNil where knownTy = STNil -instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy -instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy -instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy -instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy -instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy - -class KnownEnv env where knownEnv :: SList STy env -instance KnownEnv '[] where knownEnv = SNil -instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv - -styKnown :: STy t -> Dict (KnownTy t) -styKnown STNil = Dict -styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STMaybe t) | Dict <- styKnown t = Dict -styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict -styKnown (STScal t) | Dict <- sscaltyKnown t = Dict -styKnown (STAccum t) | Dict <- styKnown t = Dict - -sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) -sscaltyKnown STI32 = Dict -sscaltyKnown STI64 = Dict -sscaltyKnown STF32 = Dict -sscaltyKnown STF64 = Dict -sscaltyKnown STBool = Dict - -envKnown :: SList STy env -> Dict (KnownEnv env) -envKnown SNil = Dict -envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict - -eTup :: SList (Ex env) list -> Ex env (Tup list) -eTup = mkTup (ENil ext) (EPair ext) - -ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) -ebuildUp1 n sh size f = - EBuild ext (SS n) (EPair ext sh size) $ - let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ - in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) - (EFst ext arg) - -eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) -eidxEq SZ _ _ = EConst ext STBool True -eidxEq (SS SZ) a b = - EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b)) -eidxEq (SS n) a b - | let ty = tTup (sreplicate (SS n) tIx) - = ELet ext a $ - ELet ext (weakenExpr WSink b) $ - EOp ext OAnd $ EPair ext - (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ))) - (ESnd ext (EVar ext ty IZ)))) - (eidxEq n (EFst ext (EVar ext ty (IS IZ))) - (EFst ext (EVar ext ty IZ))) - -emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr = - let STArr n t = typeOf arr - in ELet ext arr $ - EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) f - -ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) -ezipWith f arr1 arr2 = - let STArr n t1 = typeOf arr1 - STArr _ t2 = typeOf arr2 - in ELet ext arr1 $ - ELet ext (weakenExpr WSink arr2) $ - EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ - weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f - -ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip arr1 arr2 = - let STArr _ t1 = typeOf arr1 - STArr _ t2 = typeOf arr2 - in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2 - -eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a -eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) - --- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty). -eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) -eshapeEmpty SZ _ = EConst ext STBool False -eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0)) -eshapeEmpty (SS n) e = - ELet ext e $ - EOp ext OAnd (EPair ext - (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) - (EConst ext STI64 0))) - (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs deleted file mode 100644 index 67c5de7..0000000 --- a/src/AST/Accum.hs +++ /dev/null @@ -1,60 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE UndecidableInstances #-} -module AST.Accum where - -import AST.Types -import Data - - -data AcPrj - = APHere - | APFst AcPrj - | APSnd AcPrj - | APLeft AcPrj - | APRight AcPrj - | APJust AcPrj - | APArrIdx AcPrj - | APArrSlice Nat - --- | @b@ is a small part of @a@, indicated by the projection @p@. -data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where - SAPHere :: SAcPrj APHere a a - SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b - SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b - SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b - SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b - SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b - -- TODO: This SNat is rather useless, you always have an STy around too - SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b - -- TODO: - -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) -deriving instance Show (SAcPrj p a b) - -type family AcIdx p t where - AcIdx APHere t = TNil - AcIdx (APFst p) (TPair a b) = AcIdx p a - AcIdx (APSnd p) (TPair a b) = AcIdx p b - AcIdx (APLeft p) (TEither a b) = AcIdx p a - AcIdx (APRight p) (TEither a b) = AcIdx p b - AcIdx (APJust p) (TMaybe a) = AcIdx p a - AcIdx (APArrIdx p) (TArr n a) = - -- ((index, array shape), recursive info) - TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx))) - (AcIdx p a) - -- AcIdx (APArrSlice m) (TArr n a) = - -- -- (index, array shape) - -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) - -acPrjTy :: SAcPrj p a b -> STy a -> STy b -acPrjTy SAPHere t = t -acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t -acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t -acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t -acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t -acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t -acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t diff --git a/src/AST/Count.hs b/src/AST/Count.hs deleted file mode 100644 index dc8ec72..0000000 --- a/src/AST/Count.hs +++ /dev/null @@ -1,164 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -module AST.Count where - -import Data.Functor.Const -import GHC.Generics (Generic, Generically(..)) - -import AST -import AST.Env -import Data - - -data Count = Zero | One | Many - deriving (Show, Eq, Ord) - -instance Semigroup Count where - Zero <> n = n - n <> Zero = n - _ <> _ = Many -instance Monoid Count where - mempty = Zero - -data Occ = Occ { _occLexical :: Count - , _occRuntime :: Count } - deriving (Eq, Generic) - deriving (Semigroup, Monoid) via Generically Occ - -instance Show Occ where - showsPrec d (Occ l r) = showParen (d > 10) $ - showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r - --- | One of the two branches is taken -(<||>) :: Occ -> Occ -> Occ -Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2) - --- | This code is executed many times -scaleMany :: Occ -> Occ -scaleMany (Occ l Zero) = Occ l Zero -scaleMany (Occ l _) = Occ l Many - -occCount :: Idx env a -> Expr x env t -> Occ -occCount idx = - getConst . occCountGeneral - (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty) - (\(Const o) -> Const o) - (\(Const o1) (Const o2) -> Const (o1 <||> o2)) - (\(Const o) -> Const (scaleMany o)) - - -data OccEnv env where - OccEnd :: OccEnv env -- not necessarily top! - OccPush :: OccEnv env -> Occ -> OccEnv (t : env) - -instance Semigroup (OccEnv env) where - OccEnd <> e = e - e <> OccEnd = e - OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o') - -instance Monoid (OccEnv env) where - mempty = OccEnd - -onehotOccEnv :: Idx env t -> Occ -> OccEnv env -onehotOccEnv IZ v = OccPush OccEnd v -onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty - -(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env -OccEnd <||>! e = e -e <||>! OccEnd = e -OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o') - -scaleManyOccEnv :: OccEnv env -> OccEnv env -scaleManyOccEnv OccEnd = OccEnd -scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o) - -occEnvPop :: OccEnv (t : env) -> OccEnv env -occEnvPop (OccPush o _) = o -occEnvPop OccEnd = OccEnd - -occCountAll :: Expr x env t -> OccEnv env -occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv - -occCountGeneral :: forall r env t x. - (forall env'. Monoid (r env')) - => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot - -> (forall env' a. r (a : env') -> r env') -- ^ unpush - -> (forall env'. r env' -> r env' -> r env') -- ^ alternation - -> (forall env'. r env' -> r env') -- ^ scale-many - -> Expr x env t -> r env -occCountGeneral onehot unpush alter many = go WId - where - go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env' - go w = \case - EVar _ _ i -> onehot w i (Occ One One) - ELet _ rhs body -> re rhs <> re1 body - EPair _ a b -> re a <> re b - EFst _ e -> re e - ESnd _ e -> re e - ENil _ -> mempty - EInl _ _ e -> re e - EInr _ _ e -> re e - ECase _ e a b -> re e <> (re1 a `alter` re1 b) - ENothing _ _ -> mempty - EJust _ e -> re e - EMaybe _ a b e -> re a <> re1 b <> re e - EConstArr{} -> mempty - EBuild _ _ a b -> re a <> many (re1 b) - EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c - ESum1Inner _ e -> re e - EUnit _ e -> re e - EReplicate1Inner _ a b -> re a <> re b - EMaximum1Inner _ e -> re e - EMinimum1Inner _ e -> re e - EConst{} -> mempty - EIdx0 _ e -> re e - EIdx1 _ a b -> re a <> re b - EIdx _ a b -> re a <> re b - EShape _ e -> re e - EOp _ _ e -> re e - ECustom _ _ _ _ _ _ _ a b -> re a <> re b - EWith _ _ a b -> re a <> re1 b - EAccum _ _ _ a b e -> re a <> re b <> re e - EZero _ _ -> mempty - EPlus _ _ a b -> re a <> re b - EOneHot _ _ _ a b -> re a <> re b - EError{} -> mempty - where - re :: Monoid (r env') => Expr x env' t'' -> r env' - re = go w - - re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env' - re1 = unpush . go (WSink .> w) - - -deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r -deleteUnused SNil OccEnd k = k SETop -deleteUnused (_ `SCons` env) OccEnd k = - deleteUnused env OccEnd $ \sub -> k (SENo sub) -deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k = - deleteUnused env occenv $ \sub -> - case count of Zero -> k (SENo sub) - _ -> k (SEYes sub) - -unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t -unsafeWeakenWithSubenv = \sub -> - subst (\x t i -> case sinkViaSubenv i sub of - Just i' -> EVar x t i' - Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") - where - sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) - sinkViaSubenv IZ (SEYes _) = Just IZ - sinkViaSubenv IZ (SENo _) = Nothing - sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub - sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs deleted file mode 100644 index 4f34166..0000000 --- a/src/AST/Env.hs +++ /dev/null @@ -1,59 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE ExplicitForAll #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -module AST.Env where - -import AST.Weaken -import Data - - --- | @env'@ is a subset of @env@: each element of @env@ is either included in --- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv env env' where - SETop :: Subenv '[] '[] - SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env') - SENo :: forall t env env'. Subenv env env' -> Subenv (t : env) env' -deriving instance Show (Subenv env env') - -subList :: SList f env -> Subenv env env' -> SList f env' -subList SNil SETop = SNil -subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) -subList (SCons _ xs) (SENo sub) = subList xs sub - -subenvAll :: SList f env -> Subenv env env -subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) - -subenvNone :: SList f env -> Subenv env '[] -subenvNone SNil = SETop -subenvNone (SCons _ env) = SENo (subenvNone env) - -subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] -subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) -subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) -subenvOnehot SNil i = case i of {} - -subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3 -subenvCompose SETop SETop = SETop -subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2) -subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) -subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) - -subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1') -subenvConcat sub1 SETop = sub1 -subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2) -subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) - -sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0 -sinkWithSubenv SETop = WId -sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub -sinkWithSubenv (SENo sub) = sinkWithSubenv sub - -wUndoSubenv :: Subenv env env' -> env' :> env -wUndoSubenv SETop = WId -wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) -wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs deleted file mode 100644 index 0da1afc..0000000 --- a/src/AST/UnMonoid.hs +++ /dev/null @@ -1,137 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus) where - -import AST -import CHAD.Types -import Data - - -unMonoid :: Ex env t -> Ex env t -unMonoid = \case - EZero _ t -> zero t - EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) - EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) - - EVar _ t i -> EVar ext t i - ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) - EPair _ a b -> EPair ext (unMonoid a) (unMonoid b) - EFst _ e -> EFst ext (unMonoid e) - ESnd _ e -> ESnd ext (unMonoid e) - ENil _ -> ENil ext - EInl _ t e -> EInl ext t (unMonoid e) - EInr _ t e -> EInr ext t (unMonoid e) - ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b) - ENothing _ t -> ENothing ext t - EJust _ e -> EJust ext (unMonoid e) - EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) - EConstArr _ n t x -> EConstArr ext n t x - EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) - EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) - ESum1Inner _ e -> ESum1Inner ext (unMonoid e) - EUnit _ e -> EUnit ext (unMonoid e) - EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) - EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) - EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) - EConst _ t x -> EConst ext t x - EIdx0 _ e -> EIdx0 ext (unMonoid e) - EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) - EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b) - EShape _ e -> EShape ext (unMonoid e) - EOp _ op e -> EOp ext op (unMonoid e) - ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) - EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) - EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e) - EError _ t s -> EError ext t s - -zero :: STy t -> Ex env (D2 t) -zero STNil = ENil ext -zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2)) -zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2)) -zero (STMaybe t) = ENothing ext (d2 t) -zero (STArr SZ t) = ENothing ext (STArr SZ (d2 t)) -zero (STArr n t) = ENothing ext (STArr n (d2 t)) -zero (STScal t) = case t of - STI32 -> ENil ext - STI64 -> ENil ext - STF32 -> EConst ext STF32 0.0 - STF64 -> EConst ext STF64 0.0 - STBool -> ENil ext -zero STAccum{} = error "Accumulators not allowed in input program" - -plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) -plus STNil _ _ = ENil ext -plus (STPair t1 t2) a b = - let t = STPair (d2 t1) (d2 t2) - in plusSparse t a b $ - EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) - (EFst ext (EVar ext t IZ))) - (plus t2 (ESnd ext (EVar ext t (IS IZ))) - (ESnd ext (EVar ext t IZ))) -plus (STEither t1 t2) a b = - let t = STEither (d2 t1) (d2 t2) - in plusSparse t a b $ - ECase ext (EVar ext t (IS IZ)) - (ECase ext (EVar ext t (IS IZ)) - (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) - (EError ext t "plus l+r")) - (ECase ext (EVar ext t (IS IZ)) - (EError ext t "plus r+l") - (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) -plus (STMaybe t) a b = - plusSparse (d2 t) a b $ - plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ) -plus (STArr n t) a b = - plusSparse (STArr n (d2 t)) a b $ - eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ) - (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ))) - (EVar ext (STArr n (d2 t)) (IS IZ)) - (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)) - (EVar ext (STArr n (d2 t)) (IS IZ)) - (EVar ext (STArr n (d2 t)) IZ))) -plus (STScal t) a b = case t of - STI32 -> ENil ext - STI64 -> ENil ext - STF32 -> EOp ext (OAdd STF32) (EPair ext a b) - STF64 -> EOp ext (OAdd STF64) (EPair ext a b) - STBool -> ENil ext -plus STAccum{} _ _ = error "Accumulators not allowed in input program" - -plusSparse :: STy a - -> Ex env (TMaybe a) -> Ex env (TMaybe a) - -> Ex (a : a : env) a - -> Ex env (TMaybe a) -plusSparse t a b adder = - ELet ext b $ - EMaybe ext - (EVar ext (STMaybe t) IZ) - (EJust ext - (EMaybe ext - (EVar ext t IZ) - (weakenExpr (WCopy (WCopy WSink)) adder) - (EVar ext (STMaybe t) (IS IZ)))) - (weakenExpr WSink a) - -onehot :: STy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env (D2 a) -> Ex env (D2 t) -onehot typ topprj idx arg = case (typ, topprj) of - (_, SAPHere) -> arg - - (STPair t1 t2, SAPFst prj) -> EJust ext (EPair ext (onehot t1 prj idx arg) (zero t2)) - (STPair t1 t2, SAPSnd prj) -> EJust ext (EPair ext (zero t1) (onehot t2 prj idx arg)) - - (STEither t1 t2, SAPLeft prj) -> EJust ext (EInl ext (d2 t2) (onehot t1 prj idx arg)) - (STEither t1 t2, SAPRight prj) -> EJust ext (EInr ext (d2 t1) (onehot t2 prj idx arg)) - - (STMaybe t1, SAPJust prj) -> EJust ext (onehot t1 prj idx arg) - - (STArr n t1, SAPArrIdx prj _) -> - let tidx = tTup (sreplicate n tIx) - in ELet ext idx $ - EJust ext $ - EBuild ext n (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ))) $ - eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) - (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) - (zero t1) diff --git a/src/CHAD.hs b/src/CHAD.hs deleted file mode 100644 index 1126fde..0000000 --- a/src/CHAD.hs +++ /dev/null @@ -1,1131 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} - --- I want to bring various type variables in scope using type annotations in --- patterns, but I don't want to have to mention all the other type parameters --- of the types in question as well then. Partial type signatures (with '_') are --- useful here. -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} -module CHAD ( - drev, - freezeRet, - CHADConfig(..), - defaultConfig, - Storage(..), - Descr(..), - Select, -) where - -import Data.Functor.Const -import Data.Some -import Data.Type.Bool (If) -import Data.Type.Equality (type (==), testEquality) -import GHC.Stack (HasCallStack) - -import Analysis.Identity (ValId(..), validSplitEither) -import AST -import AST.Bindings -import AST.Count -import AST.Env -import AST.Weaken.Auto -import CHAD.Accum -import CHAD.EnvDescr -import CHAD.Types -import Data -import qualified Data.VarMap as VarMap -import Data.VarMap (VarMap) -import Lemmas - - ------------------------------- TAPES AND BINDINGS ------------------------------ - -type family Tape binds where - Tape '[] = TNil - Tape (t : ts) = TPair t (Tape ts) - -tapeTy :: SList STy binds -> STy (Tape binds) -tapeTy SNil = STNil -tapeTy (SCons t ts) = STPair t (tapeTy ts) - -bindingsCollect :: Bindings f env binds -> Subenv binds tapebinds - -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) -bindingsCollect BTop SETop _ = ENil ext -bindingsCollect (BPush binds (t, _)) (SEYes sub) w = - EPair ext (EVar ext t (w @> IZ)) - (bindingsCollect binds sub (w .> WSink)) -bindingsCollect (BPush binds _) (SENo sub) w = - bindingsCollect binds sub (w .> WSink) - --- In order from large to small: i.e. in reverse order from what we want, --- because in a Bindings, the head of the list is the bottom-most entry. -type family TapeUnfoldings binds where - TapeUnfoldings '[] = '[] - TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts - -type family Reverse l where - Reverse '[] = '[] - Reverse (t : ts) = Append (Reverse ts) '[t] - --- An expression that is always 'snd' -data UnfExpr env t where - UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t - -fromUnfExpr :: UnfExpr env t -> Ex env t -fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ) - --- - A bunch of 'snd' expressions taking us from knowing that there's a --- 'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix --- this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in --- the environment. --- - In the extended environment, another bunch of let bindings (these are --- 'fst' expressions, but no need to know that statically) that project the --- fsts out of what we introduced above, one for each type in 'ts'. -data Reconstructor env ts = - Reconstructor - (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts))) - (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts) - -ssnoc :: SList f ts -> f t -> SList f (Append ts '[t]) -ssnoc SNil a = SCons a SNil -ssnoc (SCons t ts) a = SCons t (ssnoc ts a) - -sreverse :: SList f ts -> SList f (Reverse ts) -sreverse SNil = SNil -sreverse (SCons t ts) = ssnoc (sreverse ts) t - -stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts) -stapeUnfoldings SNil = SNil -stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts) - --- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one. -shiftUnfolder - :: STy t - -> SList STy ts - -> Bindings UnfExpr (Tape ts : env) list - -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts]) -shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts)) -shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) = - -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order - -- to expand an 'Append' in the types so that things simplify just enough. - -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list - -- of bindings produced by 'b'. We want to conclude from this that - -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know - -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after - -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step. - BPush (shiftUnfolder newTy ts b) (t, case b of BTop -> UnfExSnd itemTy t - BPush{} -> UnfExSnd itemTy t) - -growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts) -growRecon t ts (Reconstructor unfbs bs) - | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) - , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env) - , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env - = Reconstructor - (shiftUnfolder t ts unfbs) - -- Add a 'fst' at the bottom of the builder stack. - -- First we have to weaken most of 'bs' to skip one more binding in the - -- unfolder stack above it. - (BPush (fst (weakenBindings weakenExpr - (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) - (WSink :: env :> (Tape (t : ts) : env))) bs)) - (t - ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $ - wSinks @(Tape (t : ts) : env) - (sappend ts - (sappend (sappend (sreverse (stapeUnfoldings ts)) - (SCons (tapeTy ts) SNil)) - SNil)) - @> IZ)) - -buildReconstructor :: SList STy ts -> Reconstructor env ts -buildReconstructor SNil = Reconstructor BTop BTop -buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) - --- STRATEGY FOR reconstructBindings --- --- binds = [] --- e : () --- --- binds = [c] --- e : (c, ()) --- x0 = snd x1 : () --- y1 = fst e : c --- --- binds = [b, c] --- e : (b, (c, ())) --- x1 = snd e : (c, ()) --- x0 = snd x1 : () --- y1 = fst x1 : c --- y2 = fst x2 : b --- --- binds = [a, b, c] --- e : (a, (b, (c, ()))) --- x2 = snd e : (b, (c, ())) --- x1 = snd x2 : (c, ()) --- x0 = snd x1 : () --- y1 = fst x1 : c --- y2 = fst x2 : b --- y3 = fst x3 : a - --- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all --- the things in the list 'binds', we want to create a let stack that extracts --- all values from that tuple and in effect "restores" the environment --- described by 'binds'. The idea is that elsewhere, we took a slice of the --- environment and saved it all in a tuple to be restored later. We --- incidentally also add a bunch of additional bindings, namely 'Reverse --- (TapeUnfoldings binds)', so the calling code just has to skip those in --- whatever it wants to do. -reconstructBindings :: SList STy binds -> Idx env (Tape binds) - -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) - ,SList STy (Reverse (TapeUnfoldings binds))) -reconstructBindings binds tape = - let Reconstructor unf build = buildReconstructor binds - in (fst $ weakenBindings weakenExpr (WIdx tape) - (bconcat (mapBindings fromUnfExpr unf) build) - ,sreverse (stapeUnfoldings binds)) - - ----------------------------------- DERIVATIVES --------------------------------- - -d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) -d1op (OAdd t) e = EOp ext (OAdd t) e -d1op (OMul t) e = EOp ext (OMul t) e -d1op (ONeg t) e = EOp ext (ONeg t) e -d1op (OLt t) e = EOp ext (OLt t) e -d1op (OLe t) e = EOp ext (OLe t) e -d1op (OEq t) e = EOp ext (OEq t) e -d1op ONot e = EOp ext ONot e -d1op OAnd e = EOp ext OAnd e -d1op OOr e = EOp ext OOr e -d1op OIf e = EOp ext OIf e -d1op ORound64 e = EOp ext ORound64 e -d1op OToFl64 e = EOp ext OToFl64 e -d1op (ORecip t) e = EOp ext (ORecip t) e -d1op (OExp t) e = EOp ext (OExp t) e -d1op (OLog t) e = EOp ext (OLog t) e -d1op (OIDiv t) e = EOp ext (OIDiv t) e -d1op (OMod t) e = EOp ext (OMod t) e - --- | Both primal and dual must be duplicable expressions -data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) - | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) - -d2op :: SOp a t -> D2Op a t -d2op op = case op of - OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d) - OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> - EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) - (EOp ext (OMul t) (EPair ext (EFst ext e) d))) - ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d - OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) - ONot -> Linear $ \_ -> ENil ext - OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - OIf -> Linear $ \_ -> ENil ext - ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 - OToFl64 -> Linear $ \_ -> ENil ext - ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) - OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) - OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) - OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) - OMod t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) - where - d2opUnArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => D2Op (TScal a) t) - -> D2Op (TScal a) t - d2opUnArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENil ext - STI64 -> Linear $ \_ -> ENil ext - STF32 -> float - STF64 -> float - STBool -> Linear $ \_ -> ENil ext - - d2opBinArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) - -> D2Op (TPair (TScal a) (TScal a)) t - d2opBinArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - STF32 -> float - STF64 -> float - STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil) - - floatingD2 :: ScalIsFloating a ~ True - => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r - floatingD2 STF32 k = k - floatingD2 STF64 k = k - - integralD2 :: ScalIsIntegral a ~ True - => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r - integralD2 STI32 k = k - integralD2 STI64 k = k - -desD1E :: Descr env sto -> SList STy (D1E env) -desD1E = d1e . descrList - --- d1W :: env :> env' -> D1E env :> D1E env' --- d1W WId = WId --- d1W WSink = WSink --- d1W (WCopy w) = WCopy (d1W w) --- d1W (WPop w) = WPop (d1W w) --- d1W (WThen u w) = WThen (d1W u) (d1W w) - -conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) -conv1Idx IZ = IZ -conv1Idx (IS i) = IS (conv1Idx i) - -data Idx2 env sto t - = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum t)) - | Idx2Me (Idx (Select env sto "merge") t) - | Idx2Di (Idx (Select env sto "discr") t) - -conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t -conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ -conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ -conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ -conv2Idx (DPush des (_, _, SAccum)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) - Idx2Me j -> Idx2Me j - Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, _, SMerge)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac j - Idx2Me j -> Idx2Me (IS j) - Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, _, SDiscr)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac j - Idx2Me j -> Idx2Me j - Idx2Di j -> Idx2Di (IS j) -conv2Idx DTop i = case i of {} - - ------------------------------------- MONOIDS ----------------------------------- - -zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -zeroTup SNil = ENil ext -zeroTup (SCons t env) = EPair ext (zeroTup env) (EZero ext t) - - ------------------------------------- SUBENVS ----------------------------------- - -subenvPlus :: SList STy env - -> Subenv env env1 -> Subenv env env2 - -> (forall env3. Subenv env env3 - -> Subenv env3 env1 - -> Subenv env3 env2 - -> (Ex exenv (Tup (D2E env1)) - -> Ex exenv (Tup (D2E env2)) - -> Ex exenv (Tup (D2E env3))) - -> r) - -> r -subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) -subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SENo sub3) s31 s32 pl -subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (ESnd ext (EVar ext (typeOf e1) IZ)) -subenvPlus (SCons _ env) (SENo sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SENo s31) (SEYes s32) $ \e1 e2 -> - ELet ext e2 $ - EPair ext (pl (weakenExpr WSink e1) - (EFst ext (EVar ext (typeOf e2) IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ)) -subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SEYes sub3) (SEYes s31) (SEYes s32) $ \e1 e2 -> - ELet ext e1 $ - ELet ext (weakenExpr WSink e2) $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) - (EFst ext (EVar ext (typeOf e2) IZ))) - (EPlus ext t - (ESnd ext (EVar ext (typeOf e1) (IS IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ))) - -expandSubenvZeros :: SList STy env0 -> Subenv env0 env0Merge -> Ex env (Tup (D2E env0Merge)) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SETop _ = ENil ext -expandSubenvZeros (SCons t ts) (SEYes sub) e = - ELet ext e $ - let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ - in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (EZero ext t) - -assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] -assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl -assertSubenvEmpty SETop = Refl -assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" - - ---------------------------------- ACCUMULATORS --------------------------------- - -fromArrayValId :: Maybe (ValId t) -> Maybe Int -fromArrayValId (Just (VIArr i _)) = Just i -fromArrayValId _ = Nothing - -accumPromote :: forall dt env sto proxy r. - proxy dt - -> Descr env sto - -> (forall stoRepl envPro. - (Select env stoRepl "merge" ~ '[]) - => Descr env stoRepl - -- ^ A revised environment description that switches - -- arrays (used in the OccEnv) that are currently on - -- "merge" storage, to "accum" storage. - -> SList STy envPro - -- ^ New entries on top of the original dual environment, - -- that house the accumulators for the promoted arrays in - -- the original environment. - -> Subenv (Select env sto "merge") envPro - -- ^ The promoted entries were merge entries in the - -- original environment. - -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) - -- ^ All entries that were accumulators are still - -- accumulators. - -> VarMap Int (D2AcE (Select env stoRepl "accum")) - -- ^ Accumulator map for _only_ the the newly allocated - -- accumulators. - -> (forall shbinds. - SList STy shbinds - -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) - :> Append (D2AcE envPro) (D2 dt : Append shbinds (D2AcE (Select env sto "accum")))) - -- ^ A weakening that converts a computation in the - -- revised environment to one in the original environment - -- extended with some accumulators. - -> r) - -> r -accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) -accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of - -- Accumulators are left as-is - SAccum -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SAccum)) - envpro - prosub - (SEYes accrevsub) - (VarMap.sink1 accumMap) - (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum descr))) - (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) - (#pro :++: #d :++: #shb :++: #acc :++: #tl) - .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(TAccum t)) &. #tl (d2ace (select SAccum storepl))) - (#d :++: #shb :++: #acc :++: #tl) - (#acc :++: (#d :++: #shb :++: #tl))) - - SMerge -> case t of - -- Discrete values are left as-is - _ | isDiscrete t -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> - k (storepl `DPush` (t, vid, SDiscr)) - envpro - (SENo prosub) - accrevsub - accumMap' - wf - - -- Values with "merge" storage are promoted to an accumulator in envPro - _ -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SAccum)) - (t `SCons` envpro) - (SEYes prosub) - (SENo accrevsub) - (let accumMap' = VarMap.sink1 accumMap - in case fromArrayValId vid of - Just i -> VarMap.insert i (STAccum t) IZ accumMap' - Nothing -> accumMap') - (\(shbinds :: SList _ shbinds) -> - let shbindsC = slistMap (\_ -> Const ()) shbinds - in - -- wf: - -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WCopy wf: - -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WPICK: ^ THESE TWO || - -- goal: | ARE EQUAL || - -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - WCopy (wf shbinds) - .> WPick @(TAccum t) @(D2 dt : shbinds) (Const () `SCons` shbindsC) - (WId @(D2AcE (Select env1 stoRepl "accum")))) - - -- Discrete values are left as-is, nothing to do - SDiscr -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SDiscr)) - envpro - prosub - accrevsub - accumMap - wf - where - isDiscrete :: STy t' -> Bool - isDiscrete = \case - STNil -> True - STPair a b -> isDiscrete a && isDiscrete b - STEither a b -> isDiscrete a && isDiscrete b - STMaybe a -> isDiscrete a - STArr _ a -> isDiscrete a - STScal st -> case st of - STI32 -> True - STI64 -> True - STF32 -> False - STF64 -> False - STBool -> True - STAccum{} -> False - - ----------------------------- RETURN TRIPLE FROM CHAD --------------------------- - -data Ret env0 sto t = - forall shbinds tapebinds env0Merge. - Ret (Bindings Ex (D1E env0) shbinds) -- shared binds - (Subenv shbinds tapebinds) - (Ex (Append shbinds (D1E env0)) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (Ret env0 sto t) - -data RetPair env0 sto env shbinds tapebinds t = - forall env0Merge. - RetPair (Ex (Append shbinds env) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) -deriving instance Show (RetPair env0 sto env shbinds tapebinds t) - -data Rets env0 sto env list = - forall shbinds tapebinds. - Rets (Bindings Ex env shbinds) - (Subenv shbinds tapebinds) - (SList (RetPair env0 sto env shbinds tapebinds) list) -deriving instance Show (Rets env0 sto env list) - -weakenRetPair :: SList STy shbinds -> env :> env' - -> RetPair env0 sto env shbinds tapebinds t -> RetPair env0 sto env' shbinds tapebinds t -weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 - -weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list -weakenRets w (Rets binds tapesub list) = - let (binds', _) = weakenBindings weakenExpr w binds - in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) - -rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto t f. - Descr env0 sto - -> SList f b1 -> SList f b2 - -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 - -> RetPair env0 sto (Append b1 env) b2 tapebinds2 t - -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) t -rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair p sub d) - | Refl <- lemAppendAssoc @b2 @b1 @env = - RetPair p sub (weakenExpr (autoWeak - (#d (auto1 @(D2 t)) - &. #t2 (subList b2 subtape2) - &. #t1 (subList b1 subtape1) - &. #tl (d2ace (select SAccum descr))) - (#d :++: (#t2 :++: #tl)) - (#d :++: ((#t2 :++: #t1) :++: #tl))) - d) - -retConcat :: forall env0 sto list. Descr env0 sto -> SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list -retConcat _ SNil = Rets BTop SETop SNil -retConcat descr (SCons (Ret (b :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) p sub d) list) - | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs - <- weakenRets (sinkWithBindings b) (retConcat descr list) - , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) - , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) - = Rets (bconcat b binds) - (subenvConcat subtape subtape2) - (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) - sub - (weakenExpr (WCopy (sinkWithSubenv subtape2)) d)) - (slistMap (rebaseRetPair descr (bindingsBinds b) (bindingsBinds binds) - subtape subtape2) - pairs)) - -freezeRet :: Descr env sto - -> Ret env sto t - -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) -freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = - let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 - e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 - in letBinds e0' $ - EPair ext - (weakenExpr wInsertD2Ac e1) - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tape (subList (bindingsBinds e0) subtape) - &. #shbinds (bindingsBinds e0) - &. #d2ace (d2ace (select SAccum descr)) - &. #tl (desD1E descr)) - (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) - (#shbinds :++: #d :++: #d2ace :++: #tl)) - e2') $ - expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) - - ----------------------------- THE CHAD TRANSFORMATION --------------------------- - -drev :: forall env sto t. - (?config :: CHADConfig) - => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> Expr ValId env t -> Ret env sto t -drev des accumMap = \case - EVar _ t i -> - case conv2Idx des i of - Idx2Ac accI -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) - (EAccum ext t SAPHere (ENil ext) (EVar ext (d2 t) IZ) (EVar ext (STAccum t) (IS accI))) - - Idx2Me tupI -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvOnehot (select SMerge des) tupI) - (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) - - Idx2Di _ -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (select SMerge des)) - (ENil ext) - - ELet _ (rhs :: Expr _ _ a) body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des accumMap rhs - , ChosenStorage storage <- if chcLetArrayAccum ?config && hasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) body - , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 - , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) - , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) -> - subenvPlus (select SMerge des) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> - let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody))) (d2 (typeOf rhs)) in - Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') - (subenvConcat (SENo @d1_a subtapeRHS) subtapeBody) - (weakenExpr wbody0' body1) - subBoth - (ELet ext (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds body0) subtapeBody) - &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #tl) - (#d :++: (#body :++: #rhs) :++: #tl)) - body2) $ - ELet ext - (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ - plus_RHS_Body - (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) - (EFst ext (EVar ext bodyResType (IS IZ)))) - - EPair _ a b - | Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil - , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> - subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> - Ret binds - subtape - (EPair ext a1 b1) - subBoth - (EMaybe ext - (zeroTup (subList (select SMerge des) subBoth)) - (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) - (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ - ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) - (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ - plus_A_B - (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) - (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) - (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) - - EFst _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> - Ret e0 - subtape - (EFst ext e1) - sub - (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (EZero ext t2))) $ - weakenExpr (WCopy WSink) e2) - - ESnd _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STPair t1 t2 <- typeOf e -> - Ret e0 - subtape - (ESnd ext e1) - sub - (ELet ext (EJust ext (EPair ext (EZero ext t1) (EVar ext (d2 t2) IZ))) $ - weakenExpr (WCopy WSink) e2) - - ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) - - EInl _ t2 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - Ret e0 - subtape - (EInl ext (d1 t2) e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) - (weakenExpr (WCopy (wSinks' @[_,_])) e2) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) - (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ)) - - EInr _ t1 e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - Ret e0 - subtape - (EInr ext (d1 t1) e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) - (EError ext (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") - (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ)) - - ECase _ e (a :: Expr _ _ t) b - | STEither t1 t2 <- typeOf e - , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap e - , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && hasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge - , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && hasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge - , let (bindids1, bindids2) = validSplitEither (extOf e) - , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA a2 <- drevScoped des accumMap t1 storage1 bindids1 a - , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB b2 <- drevScoped des accumMap t2 storage2 bindids2 b - , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) - , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) - , let tapeA = tapeTy (subList (bindingsBinds a0) subtapeA) - , let tapeB = tapeTy (subList (bindingsBinds b0) subtapeB) - , let collectA = bindingsCollect a0 subtapeA - , let collectB = bindingsCollect b0 subtapeB - , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) - , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 - , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 - -> - subenvPlus (select SMerge des) subA subB $ \subAB sAB_A sAB_B _ -> - subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E -> - let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in - Ret (e0 `BPush` - (tPrimal, - ECase ext e1 - (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0')))) - (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) - (SEYes subtapeE) - (EFst ext (EVar ext tPrimal IZ)) - subOut - (ELet ext - (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds a0) subtapeA) IZ - in letBinds rebinds $ - ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds a0) subtapeA) prerebinds) @> IS IZ)) $ - ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #ta0 (subList (bindingsBinds a0) subtapeA) - &. #prea0 prerebinds - &. #recon (tapeA `SCons` d2 (typeOf a) `SCons` SNil) - &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) - &. #tl (d2ace (select SAccum des))) - (#d :++: #ta0 :++: #tl) - (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) - a2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)) - (EInl ext (d2 t2) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA))) (d2 t1)) IZ)))) - (let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds b0) subtapeB) IZ - in letBinds rebinds $ - ELet ext - (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_tape : D2 t : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend (subList (bindingsBinds b0) subtapeB) prerebinds) @> IS IZ)) $ - ELet ext - (weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #tb0 (subList (bindingsBinds b0) subtapeB) - &. #preb0 prerebinds - &. #recon (tapeB `SCons` d2 (typeOf a) `SCons` SNil) - &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) - &. #tl (d2ace (select SAccum des))) - (#d :++: #tb0 :++: #tl) - (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) - b2) $ - EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ)) - (EInr ext (d2 t1) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $ - ELet ext - (ELet ext (EJust ext (ESnd ext (EVar ext tCaseRet IZ))) $ - weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ - plus_AB_E - (EFst ext (EVar ext tCaseRet (IS IZ))) - (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)) - - EConst _ t val -> - Ret BTop - SETop - (EConst ext t val) - (subenvNone (select SMerge des)) - (ENil ext) - - EOp _ op e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - case d2op op of - Linear d2opfun -> - Ret e0 - subtape - (d1op op e1) - sub - (ELet ext (d2opfun (EVar ext (d2 (opt2 op)) IZ)) - (weakenExpr (WCopy WSink) e2)) - Nonlinear d2opfun -> - Ret (e0 `BPush` (d1 (typeOf e), e1)) - (SEYes subtape) - (d1op op $ EVar ext (d1 (typeOf e)) IZ) - sub - (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) - (EVar ext (d2 (opt2 op)) IZ)) - (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - - ECustom _ _ _ storety _ pr du a b - -- allowed to ignore a2 because 'a' is the part of the input that is inactive - | Rets binds subtape (RetPair a1 _ _ `SCons` RetPair b1 bsub b2 `SCons` SNil) - <- retConcat des $ drev des accumMap a `SCons` drev des accumMap b `SCons` SNil -> - Ret (binds `BPush` (typeOf a1, a1) - `BPush` (typeOf b1, weakenExpr WSink b1) - `BPush` (typeOf pr, weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) - `BPush` (storety, ESnd ext (EVar ext (typeOf pr) IZ))) - (SEYes (SENo (SENo (SENo subtape)))) - (EFst ext (EVar ext (typeOf pr) (IS IZ))) - bsub - (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ - weakenExpr (WCopy (WSink .> WSink)) b2) - - EError _ t s -> - Ret BTop - SETop - (EError ext (d1 t) s) - (subenvNone (select SMerge des)) - (ENil ext) - - EConstArr _ n t val -> - Ret BTop - SETop - (EConstArr ext n t val) - (subenvNone (select SMerge des)) - (ENil ext) - - EBuild _ (ndim :: SNat ndim) she (orige :: Expr _ _ eltty) - | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des accumMap she -- allowed to ignore she2 here because she has a discrete result - , let eltty = typeOf orige - , shty :: STy shty <- tTup (sreplicate ndim tIx) - , Refl <- indexTupD1Id ndim -> - deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> - let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in - subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 sub e2 -> - case assertSubenvEmpty sub of { Refl -> - let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in - let collectexpr = bindingsCollect e0 subtapeE in - Ret (BTop `BPush` (shty, letBinds she0 she1) - `BPush` (STArr ndim (STPair (d1 eltty) tapety) - ,EBuild ext ndim - (EVar ext shty IZ) - (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#ix :++: #sh :++: #d1env)) - e0)) $ - let w = autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#e0 :++: #ix :++: #sh :++: #d1env) - in EPair ext (weakenExpr w e1) (collectexpr w))) - `BPush` (STArr ndim tapety, emap (ESnd ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))) - (SEYes (SENo (SEYes SETop))) - (emap (EFst ext (EVar ext (STPair (d1 eltty) tapety) IZ)) - (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) - (subenvCompose subMergeUsed proSub) - (let sinkOverEnvPro = wSinks @(TArr ndim (D2 eltty) : D2 t : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : D2AcE (Select env sto "accum")) (d2ace envPro) in - EMaybe ext - (zeroTup envPro) - (ESnd ext $ - uninvertTup (d2e envPro) (STArr ndim STNil) $ - makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS (IS IZ)))) $ - -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS (IS IZ))) - (EVar ext shty (IS IZ))) $ - let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) IZ - in letBinds rebinds $ - weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) - &. #pro (d2ace envPro) - &. #etape (subList (bindingsBinds e0) subtapeE) - &. #prerebinds prerebinds - &. #tape (auto1 @(Tape e_tape)) - &. #ix (auto1 @shty) - &. #darr (auto1 @(TArr ndim (D2 eltty))) - &. #mdarr (auto1 @(TMaybe (TArr ndim (D2 eltty)))) - &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) - &. #sh (auto1 @shty) - &. #d2acUsed (d2ace (select SAccum usedDes)) - &. #d2acEnv (d2ace (select SAccum des))) - (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #mdarr :++: #tapearr :++: #sh :++: #d2acEnv) - .> wPro (subList (bindingsBinds e0) subtapeE)) - e2) - (EVar ext (d2 (STArr ndim eltty)) IZ)) - }} - - EUnit _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e -> - Ret e0 - subtape - (EUnit ext e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EIdx0 ext (EVar ext (STArr SZ (d2 (typeOf e))) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (STMaybe (STArr SZ (d2 (typeOf e)))) IZ)) - - EReplicate1Inner _ en e - -- We're allowed to ignore en2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair en1 _ _ `SCons` RetPair e1 sub e2 `SCons` SNil) - <- retConcat des $ drev des accumMap en `SCons` drev des accumMap e `SCons` SNil - , let STArr ndim eltty = typeOf e -> - Ret binds - subtape - (EReplicate1Inner ext en1 e1) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EFold1Inner ext Commut - (EPlus ext eltty (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) - (EZero ext eltty) - (EVar ext (STArr (SS ndim) (d2 eltty)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - (EVar ext (d2 (STArr (SS ndim) eltty)) IZ)) - - EIdx0 _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STArr _ t <- typeOf e -> - Ret e0 - subtape - (EIdx0 ext e1) - sub - (ELet ext (EJust ext (EUnit ext (EVar ext (d2 t) IZ))) $ - weakenExpr (WCopy WSink) e2) - - EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" - {- - EIdx1 _ e ei - -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil - , STArr (SS n) eltty <- typeOf e -> - Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ))) - (SEYes (SENo subtape)) - (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) - (weakenExpr (WSink .> WSink) ei1)) - sub - (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - -} - - EIdx _ e ei - -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil - , STArr n eltty <- typeOf e - , Refl <- indexTupD1Id n - , let tIxN = tTup (sreplicate n tIx) -> - Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) - `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) - (SEYes (SEYes (SENo subtape))) - (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - sub - (ELet ext (EOneHot ext (STArr n eltty) (SAPArrIdx SAPHere n) - (EPair ext (EPair ext (EVar ext tIxN (IS IZ)) (EVar ext tIxN (IS (IS IZ)))) - (ENil ext)) - (EVar ext (d2 eltty) IZ)) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - - EShape _ e - -- Allowed to ignore e2 here because the output of EShape is discrete, - -- hence we'd be passing a zero cotangent to e2 anyway. - | Ret e0 subtape e1 _ _ <- drev des accumMap e - , STArr n _ <- typeOf e - , Refl <- indexTupD1Id n -> - Ret e0 - subtape - (EShape ext e1) - (subenvNone (select SMerge des)) - (ENil ext) - - ESum1Inner _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , STArr (SS n) t <- typeOf e -> - Ret (e0 `BPush` (STArr (SS n) t, e1) - `BPush` (tTup (sreplicate (SS n) tIx), EShape ext (EVar ext (STArr (SS n) t) IZ))) - (SEYes (SENo subtape)) - (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS (IS IZ)))) - (EVar ext (STArr n (d2 t)) IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - (EVar ext (d2 (STArr n t)) IZ)) - - EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e - EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e - - -- These should be the next to be implemented, I think - EFold1Inner{} -> err_unsupported "EFold1Inner" - - ENothing{} -> err_unsupported "ENothing" - EJust{} -> err_unsupported "EJust" - EMaybe{} -> err_unsupported "EMaybe" - - EWith{} -> err_accum - EAccum{} -> err_accum - EZero{} -> err_monoid - EPlus{} -> err_monoid - EOneHot{} -> err_monoid - - where - err_accum = error "Accumulator operations unsupported in the source program" - err_monoid = error "Monoid operations unsupported in the source program" - err_unsupported s = error $ "CHAD: unsupported " ++ s - - deriv_extremum :: ScalIsNumeric t' ~ True - => (forall env'. Ex env' (TArr (S n) (TScal t')) -> Ex env' (TArr n (TScal t'))) - -> Expr ValId env (TArr (S n) (TScal t')) -> Ret env sto (TArr n (TScal t')) - deriv_extremum extremum e - | Ret e0 subtape e1 sub e2 <- drev des accumMap e - , at@(STArr (SS n) t@(STScal st)) <- typeOf e - , let at' = STArr n t - , let tIxN = tTup (sreplicate (SS n) tIx) = - Ret (e0 `BPush` (at, e1) - `BPush` (at', extremum (EVar ext at IZ))) - (SEYes (SEYes subtape)) - (EVar ext at' IZ) - sub - (EMaybe ext - (zeroTup (subList (select SMerge des) sub)) - (ELet ext (EJust ext - (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS (IS IZ))))) $ - eif (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS (IS IZ))))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS (IS IZ)))) (EFst ext (EVar ext tIxN IZ))))) - (EIdx ext (EVar ext (STArr n (d2 t)) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (EZero ext t))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) - (EVar ext (d2 at') IZ)) - -data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) - -data RetScoped env0 sto a s t = - forall shbinds tapebinds env0Merge. - RetScoped - (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds - (Subenv shbinds tapebinds) - (Ex (Append shbinds (D1E (a : env0))) (D1 t)) - (Subenv (Select env0 sto "merge") env0Merge) - -- ^ merge contributions to the _enclosing_ merge environment - (Ex (D2 t : Append tapebinds (D2AcE (Select env0 sto "accum"))) - (If (s == "discr") (Tup (D2E env0Merge)) - (TPair (Tup (D2E env0Merge)) (D2 a)))) - -- ^ the merge contributions, plus the cotangent to the argument - -- (if there is any) -deriving instance Show (RetScoped env0 sto a s t) - -drevScoped :: forall a s env sto t. - (?config :: CHADConfig) - => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> STy a -> Storage s -> Maybe (ValId a) - -> Expr ValId (a : env) t - -> RetScoped env sto a s t -drevScoped des accumMap argty argsto argids expr = case argsto of - SMerge - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> - case sub of - SEYes sub' -> RetScoped e0 subtape e1 sub' e2 - SENo sub' -> RetScoped e0 subtape e1 sub' (EPair ext e2 (EZero ext argty)) - - SAccum - | Just (VIArr i _) <- argids - , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap - , Just Refl <- testEquality foundTy (STAccum argty) - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) expr -> - RetScoped e0 subtape e1 sub $ - let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in - ELet ext (EVar ext (STAccum argty) (WSink .> wtapebinds @> idx)) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum a)) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: #body :++: #tl)) - -- Our contribution to the binding's cotangent _here_ is - -- zero, because we're contributing to an earlier binding - -- of the same value instead. - (EPair ext e2 (EZero ext argty)) - - | let accumMap' = case argids of - Just (VIArr i _) -> VarMap.insert i (STAccum argty) IZ (VarMap.sink1 accumMap) - _ -> VarMap.sink1 accumMap - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' expr -> - RetScoped e0 subtape e1 sub $ - EWith ext argty (EZero ext argty) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum a)) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: #body :++: #tl)) - e2 - - SDiscr - | Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap expr -> - RetScoped e0 subtape e1 sub e2 diff --git a/src/CHAD/APIv1.hs b/src/CHAD/APIv1.hs new file mode 100644 index 0000000..73d1580 --- /dev/null +++ b/src/CHAD/APIv1.hs @@ -0,0 +1,178 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.APIv1 ( + -- * Expressions and types + Ex, STy(..), SScalTy(..), Ty(..), ScalTy(..), + + -- * Reverse derivatives (Fast CHAD) + vjp, vjp', + D2, D2E, Tup, + CHADConfig(..), + + -- ** Primal type transform + -- | The primal type transform is only important when working with special + -- operations like 'CHAD.Language.custom'. + D1, + + -- * Forward derivatives (dual numbers) + jvp, jvpDN, + Tan, TanS, DN, DNS, DNE, + + -- * Working with expressions + interpret, interpret1, + compile, compile1, + fullSimplify, + SList(..), Value(..), Rep, + KnownEnv(..), KnownTy(..), + SNat(..), +) where + +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.UnMonoid +import CHAD.Compile qualified as Compile +import CHAD.Data +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.ForwardAD +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter qualified as Interpreter +import CHAD.Simplify +import CHAD.Interpreter.Rep + + +-- | Compute a reverse derivative: a vector-Jacobian product. The type has been +-- simplified with the assumption that 'D1' is the identity. +vjp :: KnownEnv env => Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +vjp = vjp' (chcSetAccum defaultConfig) + +-- | Same as 'vjp', but supply CHAD configuration. +vjp' :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +vjp' config term + | Dict <- styKnown (d2 (typeOf term)) = + fullSimplify $ + unMonoid . simplifyFix $ -- need to merge onehots and accums for unMonoid to do its work + chad' config knownEnv (simplifyFix term) + +jvpDN :: Ex env t -> Ex (DNE env) (DN t) +jvpDN = dfwdDN + +jvp :: forall s t. KnownTy s => Ex '[s] t -> Ex '[Tan s, s] (TPair t (Tan t)) +jvp term + | Dict <- styKnown (tanty (knownTy @s)) + = fullSimplify $ + elet (ezipDN knownTy) $ + elet (weakenExpr (WCopy WClosed) (jvpDN term)) $ + eunzipDN (typeOf term) + where + ezipDN :: forall env s'. STy s' -> Ex (Tan s' : s' : env) (DN s') + ezipDN STNil = ENil ext + ezipDN (STPair a b) = + EPair ext (subst (\_ t' -> \case IZ -> EFst ext (EVar ext (STPair (tanty a) (tanty b)) IZ) + IS IZ -> EFst ext (EVar ext (STPair a b) (IS IZ)) + IS (IS i) -> EVar ext t' (IS (IS i))) + (ezipDN @env a)) + (subst (\_ t' -> \case IZ -> ESnd ext (EVar ext (STPair (tanty a) (tanty b)) IZ) + IS IZ -> ESnd ext (EVar ext (STPair a b) (IS IZ)) + IS (IS i) -> EVar ext t' (IS (IS i))) + (ezipDN @env b)) + ezipDN (STEither a b) = + ecase (EVar ext (STEither a b) (IS IZ)) + (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ)) + (EInl ext (dn b) (ezipDN a)) + (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch lr")) + (ecase (EVar ext (STEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STEither (dn a) (dn b)) "jvp zip: either mismatch rl") + (EInr ext (dn a) (ezipDN b))) + ezipDN (STLEither a b) = + elcase (EVar ext (STLEither a b) (IS IZ)) + (ELNil ext (dn a) (dn b)) + (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lN") + (ELInl ext (dn b) (ezipDN a)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch lr")) + (elcase (EVar ext (STLEither (tanty a) (tanty b)) (IS IZ)) + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rN") + (EError ext (STLEither (dn a) (dn b)) "jvp zip: leither mismatch rl") + (ELInr ext (dn a) (ezipDN b))) + ezipDN (STMaybe t) = + emaybe (EVar ext (STMaybe t) (IS IZ)) + (ENothing ext (dn t)) + (emaybe (EVar ext (STMaybe (tanty t)) (IS IZ)) + (EError ext (STMaybe (dn t)) "jvp zip: maybe mismatch jN") + (EJust ext (ezipDN t))) + ezipDN (STArr n t) = + ezipWith (ezipDN t) + (EVar ext (STArr n t) (IS IZ)) (EVar ext (STArr n (tanty t)) IZ) + ezipDN (STScal st) = case st of + STF32 -> EPair ext (EVar ext (STScal STF32) (IS IZ)) (EVar ext (tanty (STScal STF32)) IZ) + STF64 -> EPair ext (EVar ext (STScal STF64) (IS IZ)) (EVar ext (tanty (STScal STF64)) IZ) + STI32 -> EVar ext (STScal STI32) (IS IZ) + STI64 -> EVar ext (STScal STI64) (IS IZ) + STBool -> EVar ext (STScal STBool) (IS IZ) + ezipDN STAccum{} = error "jvp: Accumulators not supported in source program" + + eunzipDN :: forall env t'. STy t' -> Ex (DN t' : env) (TPair t' (Tan t')) + eunzipDN STNil = EPair ext (ENil ext) (ENil ext) + eunzipDN (STPair a b) = + eunPair (subst0 (EFst ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN a)) $ \w1 ea1 ea2 -> + eunPair (weakenExpr w1 (subst0 (ESnd ext (EVar ext (STPair (dn a) (dn b)) IZ)) (eunzipDN b))) $ \w2 eb1 eb2 -> + EPair ext (EPair ext (weakenExpr w2 ea1) eb1) (EPair ext (weakenExpr w2 ea2) eb2) + eunzipDN (STEither a b) = + ecase (EVar ext (STEither (dn a) (dn b)) IZ) + (eunPair (eunzipDN a) $ \_ a1 a2 -> + EPair ext (EInl ext b a1) (EInl ext (tanty b) a2)) + (eunPair (eunzipDN b) $ \_ b1 b2 -> + EPair ext (EInr ext a b1) (EInr ext (tanty a) b2)) + eunzipDN (STLEither a b) = + elcase (EVar ext (STLEither (dn a) (dn b)) IZ) + (EPair ext (ELNil ext a b) (ELNil ext (tanty a) (tanty b))) + (eunPair (eunzipDN a) $ \_ a1 a2 -> + EPair ext (ELInl ext b a1) (ELInl ext (tanty b) a2)) + (eunPair (eunzipDN b) $ \_ b1 b2 -> + EPair ext (ELInr ext a b1) (ELInr ext (tanty a) b2)) + eunzipDN (STMaybe t) = + emaybe (EVar ext (STMaybe (dn t)) IZ) + (EPair ext (ENothing ext t) (ENothing ext (tanty t))) + (eunPair (eunzipDN t) $ \_ e1 e2 -> + EPair ext (EJust ext e1) (EJust ext e2)) + eunzipDN (STArr n t) = + elet (emap (eunzipDN t) (EVar ext (STArr n (dn t)) IZ)) $ + EPair ext (emap (EFst ext (evar IZ)) (evar IZ)) + (emap (ESnd ext (evar IZ)) (evar IZ)) + eunzipDN (STScal st) = case st of + STF32 -> EVar ext (STPair (STScal STF32) (STScal STF32)) IZ + STF64 -> EVar ext (STPair (STScal STF64) (STScal STF64)) IZ + STI32 -> EPair ext (EVar ext (STScal STI32) IZ) (ENil ext) + STI64 -> EPair ext (EVar ext (STScal STI64) IZ) (ENil ext) + STBool -> EPair ext (EVar ext (STScal STBool) IZ) (ENil ext) + eunzipDN STAccum{} = error "jvp: Accumulators not supported in source program" + +-- | Interpret an expression in a given environment. +interpret :: KnownEnv env => SList Value env -> Ex env t -> Rep t +interpret = Interpreter.interpretOpen False knownEnv + +-- | Special case of 'interpret' for an expression with a single free variable. +interpret1 :: KnownTy s => Rep s -> Ex '[s] t -> Rep t +interpret1 x = interpret (Value x `SCons` SNil) + +-- | Compile an expression to C, load the resulting shared object into the +-- program and wrap it in a Haskell function. +compile :: KnownEnv env => Ex env t -> IO (SList Value env -> IO (Rep t)) +compile = Compile.compileStderr knownEnv + +-- | Special case of 'compile' for an expression with a single free variable. +compile1 :: KnownTy s => Ex '[s] t -> IO (Rep s -> IO (Rep t)) +compile1 term = do + f <- Compile.compileStderr knownEnv term + return (\x -> f (Value x `SCons` SNil)) + +-- | Simplify an expression. The 'vjp'/'jvp' functions already do this automatically. +fullSimplify :: KnownEnv env => Ex env t -> Ex env t +fullSimplify = simplifyFix . pruneExpr knownEnv . simplifyFix diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs new file mode 100644 index 0000000..ce9eb20 --- /dev/null +++ b/src/CHAD/AST.hs @@ -0,0 +1,709 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.AST (module CHAD.AST, module CHAD.AST.Types, module CHAD.AST.Accum, module CHAD.AST.Weaken) where + +import Data.Functor.Const +import Data.Functor.Identity +import Data.Int (Int64) +import Data.Kind (Type) + +import CHAD.Array +import CHAD.AST.Accum +import CHAD.AST.Sparse.Types +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types + + +-- General assumption: head of the list (whatever way it is associated) is the +-- inner variable / inner array dimension. In pretty printing, the inner +-- variable / inner dimension is printed on the _right_. +-- +-- All the monoid operations are unsupposed as the input to CHAD, and are +-- intended to be eliminated after simplification, so that the input program as +-- well as the output program do not contain these constructors. +-- TODO: ensure this by a "stage" type parameter. +type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type +data Expr x env t where + -- lambda calculus + EVar :: x t -> STy t -> Idx env t -> Expr x env t + ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t + + -- base types + EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b) + EFst :: x a -> Expr x env (TPair a b) -> Expr x env a + ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b + ENil :: x TNil -> Expr x env TNil + EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b) + EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b) + ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c + ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) + EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) + EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b + + -- array operations + EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) + EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) + EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) + -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) + EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) + EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) + EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b)) + + -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: + -- an implementation is allowed to parallelise this thing and store the b + -- values in some implementation-defined order. + -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. + EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative + -> Expr x (TPair t1 t1 : env) (TPair t1 b) + -> Expr x env t1 + -> Expr x env (TArr (S n) t1) + -> Expr x env (TPair (TArr n t1) -- normal primal fold output + (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) + -- Reverse derivative of EFold1Inner. The contributions to the initial + -- element are not yet added together here; we assume a later fusion system + -- does that for us. + EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative + -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) + -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 + -> Expr x env (TArr n t2) -- incoming cotangent + -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array + + -- expression operations + EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) + EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t + EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) + EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t + EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) + EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t + + -- custom derivatives + -- 'b' is the part of the input of the operation that derivatives should + -- be backpropagated to; 'a' is the inactive part. The dual field of + -- ECustom does not allow a derivative to be generated for 'a', and hence + -- none is propagated. + -- No accumulators are allowed inside a, b and tape. This restriction is + -- currently not used very much, so could be relaxed in the future; be sure + -- to check this requirement whenever it is necessary for soundness! + ECustom :: x t -> STy a -> STy b -> STy tape + -> Expr x [b, a] t -- ^ regular operation + -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative + -> Expr x env a -> Expr x env b + -> Expr x env t + + -- fake halfway checkpointing + ERecompute :: x t -> Expr x env t -> Expr x env t + + -- accumulation effect on monoids + -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it + -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not + -- need to create any zeros. + EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) + -- The 'Sparse' here is eliminated to dense by UnMonoid. + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil + + -- monoidal operations (to be desugared to regular operations after simplification) + EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t + EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t + EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t + + -- interface of abstract monoidal types + ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) + ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) + ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) + ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c + + -- partiality + EError :: x a -> STy a -> String -> Expr x env a +deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) + +-- | A (well-typed, well-scoped) expression using De Bruijn indices. The full +-- 'Expr' type is parametrised on an indexed type of "additional info" (@x@); +-- 'Ex' sets this to nothing. +-- +-- Construct expressions using the functions in "CHAD.Language". +-- +-- Use 'CHAD.AST.Pretty.pprintExpr' or 'CHAD.AST.Pretty.ppExpr' to inspect +-- expressions. +type Ex = Expr (Const ()) + +ext :: Const () a +ext = Const () + +data Commutative = Commut | Noncommut + deriving (Show) + +type SOp :: Ty -> Ty -> Type +data SOp a t where + OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + ONot :: SOp (TScal TBool) (TScal TBool) + OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) + OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) + OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right + ORound64 :: SOp (TScal TF64) (TScal TI64) + OToFl64 :: SOp (TScal TI64) (TScal TF64) + ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) +deriving instance Show (SOp a t) + +opt1 :: SOp a t -> STy a +opt1 = \case + OAdd t -> STPair (STScal t) (STScal t) + OMul t -> STPair (STScal t) (STScal t) + ONeg t -> STScal t + OLt t -> STPair (STScal t) (STScal t) + OLe t -> STPair (STScal t) (STScal t) + OEq t -> STPair (STScal t) (STScal t) + ONot -> STScal STBool + OAnd -> STPair (STScal STBool) (STScal STBool) + OOr -> STPair (STScal STBool) (STScal STBool) + OIf -> STScal STBool + ORound64 -> STScal STF64 + OToFl64 -> STScal STI64 + ORecip t -> STScal t + OExp t -> STScal t + OLog t -> STScal t + OIDiv t -> STPair (STScal t) (STScal t) + OMod t -> STPair (STScal t) (STScal t) + +opt2 :: SOp a t -> STy t +opt2 = \case + OAdd t -> STScal t + OMul t -> STScal t + ONeg t -> STScal t + OLt _ -> STScal STBool + OLe _ -> STScal STBool + OEq _ -> STScal STBool + ONot -> STScal STBool + OAnd -> STScal STBool + OOr -> STScal STBool + OIf -> STEither STNil STNil + ORound64 -> STScal STI64 + OToFl64 -> STScal STF64 + ORecip t -> STScal t + OExp t -> STScal t + OLog t -> STScal t + OIDiv t -> STScal t + OMod t -> STScal t + +typeOf :: Expr x env t -> STy t +typeOf = \case + EVar _ t _ -> t + ELet _ _ e -> typeOf e + + EPair _ a b -> STPair (typeOf a) (typeOf b) + EFst _ e | STPair t _ <- typeOf e -> t + ESnd _ e | STPair _ t <- typeOf e -> t + ENil _ -> STNil + EInl _ t2 e -> STEither (typeOf e) t2 + EInr _ t1 e -> STEither t1 (typeOf e) + ECase _ _ a _ -> typeOf a + ENothing _ t -> STMaybe t + EJust _ e -> STMaybe (typeOf e) + EMaybe _ e _ _ -> typeOf e + ELNil _ t1 t2 -> STLEither t1 t2 + ELInl _ t2 e -> STLEither (typeOf e) t2 + ELInr _ t1 e -> STLEither t1 (typeOf e) + ELCase _ _ a _ _ -> typeOf a + + EConstArr _ n t _ -> STArr n (STScal t) + EBuild _ n _ e -> STArr n (typeOf e) + EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a) + EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t + ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EUnit _ e -> STArr SZ (typeOf e) + EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t + EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t + EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2) + + EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) + EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) + + EConst _ t _ -> STScal t + EIdx0 _ e | STArr _ t <- typeOf e -> t + EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t + EIdx _ e _ | STArr _ t <- typeOf e -> t + EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) + EOp _ op _ -> opt2 op + + ECustom _ _ _ _ e _ _ _ _ -> typeOf e + ERecompute _ e -> typeOf e + + EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) + EAccum _ _ _ _ _ _ _ -> STNil + + EZero _ t _ -> fromSMTy t + EDeepZero _ t _ -> fromSMTy t + EPlus _ t _ _ -> fromSMTy t + EOneHot _ t _ _ _ -> fromSMTy t + + EError _ t _ -> t + +extOf :: Expr x env t -> x t +extOf = \case + EVar x _ _ -> x + ELet x _ _ -> x + EPair x _ _ -> x + EFst x _ -> x + ESnd x _ -> x + ENil x -> x + EInl x _ _ -> x + EInr x _ _ -> x + ECase x _ _ _ -> x + ENothing x _ -> x + EJust x _ -> x + EMaybe x _ _ _ -> x + ELNil x _ _ -> x + ELInl x _ _ -> x + ELInr x _ _ -> x + ELCase x _ _ _ _ -> x + EConstArr x _ _ _ -> x + EBuild x _ _ _ -> x + EMap x _ _ -> x + EFold1Inner x _ _ _ _ -> x + ESum1Inner x _ -> x + EUnit x _ -> x + EReplicate1Inner x _ _ -> x + EMaximum1Inner x _ -> x + EMinimum1Inner x _ -> x + EReshape x _ _ _ -> x + EZip x _ _ -> x + EFold1InnerD1 x _ _ _ _ -> x + EFold1InnerD2 x _ _ _ _ -> x + EConst x _ _ -> x + EIdx0 x _ -> x + EIdx1 x _ _ -> x + EIdx x _ _ -> x + EShape x _ -> x + EOp x _ _ -> x + ECustom x _ _ _ _ _ _ _ _ -> x + ERecompute x _ -> x + EWith x _ _ _ -> x + EAccum x _ _ _ _ _ _ -> x + EZero x _ _ -> x + EDeepZero x _ _ -> x + EPlus x _ _ _ -> x + EOneHot x _ _ _ _ -> x + EError x _ _ -> x + +mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t +mapExt f = runIdentity . travExt (Identity . f) + +{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-} +travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t) +travExt f = \case + EVar x t i -> EVar <$> f x <*> pure t <*> pure i + ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body + EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b + EFst x e -> EFst <$> f x <*> travExt f e + ESnd x e -> ESnd <$> f x <*> travExt f e + ENil x -> ENil <$> f x + EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e + EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e + ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b + ENothing x t -> ENothing <$> f x <*> pure t + EJust x e -> EJust <$> f x <*> travExt f e + EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e + ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2 + ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e + ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e + ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c + EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a + EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b + EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b + EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e + EUnit x e -> EUnit <$> f x <*> travExt f e + EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b + EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e + EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b + EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b + EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EConst x t v -> EConst <$> f x <*> pure t <*> pure v + EIdx0 x e -> EIdx0 <$> f x <*> travExt f e + EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b + EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es + EShape x e -> EShape <$> f x <*> travExt f e + EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e + ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 + ERecompute x e -> ERecompute <$> f x <*> travExt f e + EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 + EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3 + EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e + EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e + EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b + EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b + EError x t s -> EError <$> f x <*> pure t <*> pure s + +substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t +substInline repl = + subst $ \x t -> \case IZ -> repl + IS i -> EVar x t i + +subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t +subst0 repl = + subst $ \_ t -> \case IZ -> repl + IS i -> EVar ext t (IS i) + +subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) + -> Expr x env t -> Expr x env' t +subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId + +subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) + -> env' :> envOut + -> Expr x env t + -> Expr x envOut t +subst' f w = \case + EVar x t i -> f x t w i + ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body) + EPair x a b -> EPair x (subst' f w a) (subst' f w b) + EFst x e -> EFst x (subst' f w e) + ESnd x e -> ESnd x (subst' f w e) + ENil x -> ENil x + EInl x t e -> EInl x t (subst' f w e) + EInr x t e -> EInr x t (subst' f w e) + ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) + ENothing x t -> ENothing x t + EJust x e -> EJust x (subst' f w e) + EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) + ELNil x t1 t2 -> ELNil x t1 t2 + ELInl x t e -> ELInl x t (subst' f w e) + ELInr x t e -> ELInr x t (subst' f w e) + ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c) + EConstArr x n t a -> EConstArr x n t a + EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) + EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b) + EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) + ESum1Inner x e -> ESum1Inner x (subst' f w e) + EUnit x e -> EUnit x (subst' f w e) + EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) + EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) + EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) + EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) + EZip x a b -> EZip x (subst' f w a) (subst' f w b) + EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) + EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EConst x t v -> EConst x t v + EIdx0 x e -> EIdx0 x (subst' f w e) + EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) + EIdx x e es -> EIdx x (subst' f w e) (subst' f w es) + EShape x e -> EShape x (subst' f w e) + EOp x op e -> EOp x op (subst' f w e) + ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) + ERecompute x e -> ERecompute x (subst' f w e) + EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3) + EZero x t e -> EZero x t (subst' f w e) + EDeepZero x t e -> EDeepZero x t (subst' f w e) + EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) + EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) + EError x t s -> EError x t s + where + sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) + -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t + sinkF f' x' t w' = \case + IZ -> EVar x' t (w' @> IZ) + IS i -> f' x' t (WPop w') i + +weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t +weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) + +class KnownScalTy t where knownScalTy :: SScalTy t +instance KnownScalTy TI32 where knownScalTy = STI32 +instance KnownScalTy TI64 where knownScalTy = STI64 +instance KnownScalTy TF32 where knownScalTy = STF32 +instance KnownScalTy TF64 where knownScalTy = STF64 +instance KnownScalTy TBool where knownScalTy = STBool + +class KnownTy t where knownTy :: STy t +instance KnownTy TNil where knownTy = STNil +instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy +instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy +instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy +instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy +instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy + +class KnownMTy t where knownMTy :: SMTy t +instance KnownMTy TNil where knownMTy = SMTNil +instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy +instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy +instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy +instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy +instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy + +class KnownEnv env where knownEnv :: SList STy env +instance KnownEnv '[] where knownEnv = SNil +instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv + +styKnown :: STy t -> Dict (KnownTy t) +styKnown STNil = Dict +styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STMaybe t) | Dict <- styKnown t = Dict +styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict +styKnown (STScal t) | Dict <- sscaltyKnown t = Dict +styKnown (STAccum t) | Dict <- smtyKnown t = Dict + +smtyKnown :: SMTy t -> Dict (KnownMTy t) +smtyKnown SMTNil = Dict +smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict +smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict +smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict + +sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) +sscaltyKnown STI32 = Dict +sscaltyKnown STI64 = Dict +sscaltyKnown STF32 = Dict +sscaltyKnown STF64 = Dict +sscaltyKnown STBool = Dict + +envKnown :: SList STy env -> Dict (KnownEnv env) +envKnown SNil = Dict +envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict + +cheapExpr :: Expr x env t -> Bool +cheapExpr = \case + EVar{} -> True + ENil{} -> True + EConst{} -> True + EFst _ e -> cheapExpr e + ESnd _ e -> cheapExpr e + EUnit _ e -> cheapExpr e + _ -> False + +eTup :: SList (Ex env) list -> Ex env (Tup list) +eTup = mkTup (ENil ext) (EPair ext) + +ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) +ebuildUp1 n sh size f = + EBuild ext (SS n) (EPair ext sh size) $ + let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ + in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) + (EFst ext arg) + +eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) +eidxEq SZ _ _ = EConst ext STBool True +eidxEq (SS SZ) a b = + EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b)) +eidxEq (SS n) a b + | let ty = tTup (sreplicate (SS n) tIx) + = ELet ext a $ + ELet ext (weakenExpr WSink b) $ + EOp ext OAnd $ EPair ext + (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ))) + (ESnd ext (EVar ext ty IZ)))) + (eidxEq n (EFst ext (EVar ext ty (IS IZ))) + (EFst ext (EVar ext ty IZ))) + +emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr + | STArr _ t <- typeOf arr + , Dict <- styKnown t + = EMap ext f arr + +ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) +ezipWith f arr1 arr2 + | STArr _ t1 <- typeOf arr1 + , STArr _ t2 <- typeOf arr2 + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ) + IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ) + IS (IS i) -> EVar ext t (IS i)) + f) + (EZip ext arr1 arr2) + +ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) +ezip = EZip ext + +eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a +eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) + +-- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty). +eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) +eshapeEmpty SZ _ = EConst ext STBool False +eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0)) +eshapeEmpty (SS n) e = + ELet ext e $ + EOp ext OAnd (EPair ext + (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) + (EConst ext STI64 0))) + (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) + +eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx)) +eshapeConst ShNil = ENil ext +eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n)) + +eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx +eshapeProd SZ _ = EConst ext STI64 1 +eshapeProd (SS SZ) e = ESnd ext e +eshapeProd (SS n) e = + eunPair e $ \_ e1 e2 -> + EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2) + +eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t) +eflatten e = + let STArr n _ = typeOf e + in elet e $ + EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ) + +-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) +-- ezeroD2 t ezi = EZero ext (d2M t) ezi + +-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil +-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea + +-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) +-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev + +eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r +eunPair (EPair _ e1 e2) k = k WId e1 e2 +eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e) +eunPair e k = + elet e $ + k WSink + (EFst ext (evar IZ)) + (ESnd ext (evar IZ)) + +efst :: Ex env (TPair a b) -> Ex env a +efst (EPair _ e1 _) = e1 +efst e = EFst ext e + +esnd :: Ex env (TPair a b) -> Ex env b +esnd (EPair _ _ e2) = e2 +esnd e = ESnd ext e + +elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b +elet rhs body + | Dict <- styKnown (typeOf rhs) + = if cheapExpr rhs + then substInline rhs body + else ELet ext rhs body + +-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost) +use :: Ex env a -> Ex env b -> Ex env b +use a b = elet a $ weakenExpr WSink b + +emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b +emaybe e a b + | STMaybe t <- typeOf e + , Dict <- styKnown t + = EMaybe ext a b e + +ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c +ecase e a b + | STEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ECase ext e a b + +elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c +elcase e a b c + | STLEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ELCase ext e a b c + +evar :: KnownTy a => Idx env a -> Ex env a +evar = EVar ext knownTy + +makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) +makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) + where + -- invariant: expression argument is duplicable + go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) + go SMTNil _ = ENil ext + go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) + go SMTLEither{} _ = ENil ext + go SMTMaybe{} _ = ENil ext + go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e + go SMTScal{} _ = ENil ext + +splitSparsePair + :: -- given a sparsity + STy (TPair a b) -> Sparse (TPair a b) t' + -> (forall a' b'. + -- I give you back two sparsities for a and b + Sparse a a' -> Sparse b b' + -- furthermore, I tell you that either your t' is already this (a', b') pair... + -> Either + (t' :~: TPair a' b') + -- or I tell you how to construct a' and b' from t', given an actual t' + (forall r' env. + Idx env t' + -> (forall env'. + (forall c. Ex env' c -> Ex env c) + -> Ex env' a' -> Ex env' b' -> r') + -> r') + -> r) + -> r +splitSparsePair _ SpAbsent k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +splitSparsePair _ (SpPair s1 s2) k1 = + k1 s1 s2 $ Left Refl +splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k = + let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in + k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 -> + k2 (elet $ + emaybe (EVar ext (STMaybe (applySparse s t)) i) + (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2))) + (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ))))) + (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ)) + +splitSparsePair _ (SpSparse SpAbsent) k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +-- -- TODO: having to handle sparse-of-sparse at all is ridiculous +splitSparsePair t (SpSparse (SpSparse s)) k = + splitSparsePair t (SpSparse s) $ \s1 s2 eres -> + k s1 s2 $ Right $ \i k2 -> + case eres of + Left refl -> case refl of {} + Right f -> + f IZ $ \wrap e1 e2 -> + k2 (\body -> + elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i) + (ENothing ext (applySparse s t)) + (evar IZ)) $ + wrap body) + e1 e2 diff --git a/src/CHAD/AST/Accum.hs b/src/CHAD/AST/Accum.hs new file mode 100644 index 0000000..ea74a95 --- /dev/null +++ b/src/CHAD/AST/Accum.hs @@ -0,0 +1,137 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeData #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.AST.Accum where + +import CHAD.AST.Types +import CHAD.Data + + +data AcPrj + = APHere + | APFst AcPrj + | APSnd AcPrj + | APLeft AcPrj + | APRight AcPrj + | APJust AcPrj + | APArrIdx AcPrj + | APArrSlice Nat + +-- | @b@ is a small part of @a@, indicated by the projection @p@. +data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where + SAPHere :: SAcPrj APHere a a + SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b + SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b + SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b + SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b + SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b + SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b + -- TODO: + -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) +deriving instance Show (SAcPrj p a b) + +type data AIDense = AID | AIS + +data SAIDense d where + SAID :: SAIDense AID + SAIS :: SAIDense AIS +deriving instance Show (SAIDense d) + +type family AcIdx d p t where + AcIdx d APHere t = TNil + AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a + AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b + AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) + AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) + AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a + AcIdx d (APRight p) (TLEither a b) = AcIdx d p b + AcIdx d (APJust p) (TMaybe a) = AcIdx d p a + AcIdx AID (APArrIdx p) (TArr n a) = + -- (index, recursive info) + TPair (Tup (Replicate n TIx)) (AcIdx AID p a) + AcIdx AIS (APArrIdx p) (TArr n a) = + -- ((index, shape info), recursive info) + TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) + (AcIdx AIS p a) + -- AcIdx AID (APArrSlice m) (TArr n a) = + -- -- index + -- Tup (Replicate m TIx) + -- AcIdx AIS (APArrSlice m) (TArr n a) = + -- -- (index, array shape) + -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) + +type AcIdxD p t = AcIdx AID p t +type AcIdxS p t = AcIdx AIS p t + +acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b +acPrjTy SAPHere t = t +acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t +acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t +acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t +acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t +acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t +acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t + +type family ZeroInfo t where + ZeroInfo TNil = TNil + ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) + ZeroInfo (TLEither a b) = TNil + ZeroInfo (TMaybe a) = TNil + ZeroInfo (TArr n t) = TArr n (ZeroInfo t) + ZeroInfo (TScal t) = TNil + +tZeroInfo :: SMTy t -> STy (ZeroInfo t) +tZeroInfo SMTNil = STNil +tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) +tZeroInfo (SMTLEither _ _) = STNil +tZeroInfo (SMTMaybe _) = STNil +tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) +tZeroInfo (SMTScal _) = STNil + +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Should be constructable from a D1. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil + +-- -- | Additional info needed for accumulation. This is empty unless there is +-- -- sparsity in the monoid. +-- type family AccumInfo t where +-- AccumInfo TNil = TNil +-- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b) +-- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- AccumInfo (TMaybe a) = TMaybe (AccumInfo a) +-- AccumInfo (TArr n t) = TArr n (AccumInfo t) +-- AccumInfo (TScal t) = TNil + +-- type family PrimalInfo t where +-- PrimalInfo TNil = TNil +-- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a) +-- PrimalInfo (TArr n t) = TArr n (PrimalInfo t) +-- PrimalInfo (TScal t) = TNil + +-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t) +-- tPrimalInfo SMTNil = STNil +-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a) +-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t) +-- tPrimalInfo (SMTScal _) = STNil diff --git a/src/AST/Bindings.hs b/src/CHAD/AST/Bindings.hs index 3d99afe..c1a1e77 100644 --- a/src/AST/Bindings.hs +++ b/src/CHAD/AST/Bindings.hs @@ -13,11 +13,12 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} -module AST.Bindings where +module CHAD.AST.Bindings where -import AST -import Data -import Lemmas +import CHAD.AST +import CHAD.AST.Env +import CHAD.Data +import CHAD.Lemmas -- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'. @@ -27,6 +28,10 @@ data Bindings f env binds where deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') infixl `BPush` +bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds) +bpush b e = b `BPush` (typeOf e, e) +infixl `bpush` + mapBindings :: (forall env' t'. f env' t' -> g env' t') -> Bindings f env binds -> Bindings g env binds mapBindings _ BTop = BTop @@ -41,6 +46,11 @@ weakenBindings wf w (BPush b (t, x)) = let (b', w') = weakenBindings wf w b in (BPush b' (t, wf w' x), WCopy w') +weakenBindingsE :: env1 :> env2 + -> Bindings (Expr x) env1 binds + -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2) +weakenBindingsE = weakenBindings weakenExpr + weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' weakenOver SNil w = w weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) @@ -62,3 +72,13 @@ bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds) letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t letBinds BTop = id letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs + +collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env' +collectBindings = \env -> fst . go env WId + where + go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0) + go _ _ SETop = (BTop, WId) + go (ty `SCons` env) w (SEYesR sub) = + let (bs, w') = go env (WPop w) sub + in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w') + go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/CHAD/AST/Count.hs b/src/CHAD/AST/Count.hs new file mode 100644 index 0000000..46173d2 --- /dev/null +++ b/src/CHAD/AST/Count.hs @@ -0,0 +1,927 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +module CHAD.AST.Count where + +import Data.Functor.Product +import Data.Some +import Data.Type.Equality +import GHC.Generics (Generic, Generically(..)) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Env +import CHAD.Data + + +-- | The monoid operation combines assuming that /both/ branches are taken. +class Monoid a => Occurrence a where + -- | One of the two branches is taken + (<||>) :: a -> a -> a + -- | This code is executed many times + scaleMany :: a -> a + + +data Count = Zero | One | Many + deriving (Show, Eq, Ord) + +instance Semigroup Count where + Zero <> n = n + n <> Zero = n + _ <> _ = Many +instance Monoid Count where + mempty = Zero +instance Occurrence Count where + (<||>) = max + scaleMany Zero = Zero + scaleMany _ = Many + +data Occ = Occ { _occLexical :: Count + , _occRuntime :: Count } + deriving (Eq, Generic) + deriving (Semigroup, Monoid) via Generically Occ + +instance Show Occ where + showsPrec d (Occ l r) = showParen (d > 10) $ + showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r + +instance Occurrence Occ where + Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2) + scaleMany (Occ l c) = Occ l (scaleMany c) + + +data Substruc t t' where + -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below + SsFull :: Substruc t t + SsNone :: Substruc t TNil + SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b') + SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b') + SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b') + SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a') + SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements + SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a') + +pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t' +pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2) + where SsPair' = SsPair +{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-} + +pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t' +pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s) + where SsArr' = SsArr +{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-} + +instance Semigroup (Some (Substruc t)) where + Some SsFull <> _ = Some SsFull + _ <> Some SsFull = Some SsFull + Some SsNone <> s = s + s <> Some SsNone = s + Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2) + Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2) + Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2) + Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2) + Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2) + Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2) +instance Monoid (Some (Substruc t)) where + mempty = Some SsNone + +instance TestEquality (Substruc t) where + testEquality SsFull s = isFull s + testEquality s SsFull = sym <$> isFull s + testEquality SsNone SsNone = Just Refl + testEquality SsNone _ = Nothing + testEquality _ SsNone = Nothing + testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + +isFull :: Substruc t t' -> Maybe (t :~: t') +isFull SsFull = Just Refl +isFull SsNone = Nothing -- TODO: nil? +isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing + +applySubstruc :: Substruc t t' -> STy t -> STy t' +applySubstruc SsFull t = t +applySubstruc SsNone _ = STNil +applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t) +applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t) +applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t) + +applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t' +applySubstrucM SsFull t = t +applySubstrucM SsNone _ = SMTNil +applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t) +applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t) +applySubstrucM _ t = case t of {} + +data ExMap a b = ExMap (forall env. Ex env a -> Ex env b) + | a ~ b => ExMapId + +fromExMap :: ExMap a b -> Ex env a -> Ex env b +fromExMap (ExMap f) = f +fromExMap ExMapId = id + +simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t' +simplifySubstruc STNil SsNone = SsFull + +simplifySubstruc _ SsFull = SsFull +simplifySubstruc _ SsNone = SsNone +simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s) +simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s) +simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s) + +-- simplifySubstruc' :: Substruc t t' +-- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r +-- simplifySubstruc' SsFull k = k SsFull ExMapId +-- simplifySubstruc' SsNone k = k SsNone ExMapId +-- simplifySubstruc' (SsPair s1 s2) k = +-- simplifySubstruc' s1 $ \s1' f1 -> +-- simplifySubstruc' s2 $ \s2' f2 -> +-- case (s1', s2') of +-- (SsFull, SsFull) -> +-- k SsFull (case (f1, f2) of +-- (ExMapId, ExMapId) -> ExMapId +-- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 -> +-- EPair ext (fromExMap f1 e1) (fromExMap f2 e2))) +-- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext)))) +-- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ))))) +-- simplifySubstruc' _ _ = _ + +-- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b) +-- ssUnpair SsFull = (SsFull, SsFull) +-- ssUnpair SsNone = (SsNone, SsNone) +-- ssUnpair (SsPair a b) = (a, b) + +-- ssUnleft :: Substruc (TEither a b) -> Substruc a +-- ssUnleft SsFull = SsFull +-- ssUnleft SsNone = SsNone +-- ssUnleft (SsEither a _) = a + +-- ssUnright :: Substruc (TEither a b) -> Substruc b +-- ssUnright SsFull = SsFull +-- ssUnright SsNone = SsNone +-- ssUnright (SsEither _ b) = b + +-- ssUnlleft :: Substruc (TLEither a b) -> Substruc a +-- ssUnlleft SsFull = SsFull +-- ssUnlleft SsNone = SsNone +-- ssUnlleft (SsLEither a _) = a + +-- ssUnlright :: Substruc (TLEither a b) -> Substruc b +-- ssUnlright SsFull = SsFull +-- ssUnlright SsNone = SsNone +-- ssUnlright (SsLEither _ b) = b + +-- ssUnjust :: Substruc (TMaybe a) -> Substruc a +-- ssUnjust SsFull = SsFull +-- ssUnjust SsNone = SsNone +-- ssUnjust (SsMaybe a) = a + +-- ssUnarr :: Substruc (TArr n a) -> Substruc a +-- ssUnarr SsFull = SsFull +-- ssUnarr SsNone = SsNone +-- ssUnarr (SsArr a) = a + +-- ssUnaccum :: Substruc (TAccum a) -> Substruc a +-- ssUnaccum SsFull = SsFull +-- ssUnaccum SsNone = SsNone +-- ssUnaccum (SsAccum a) = a + + +type family MapEmpty env where + MapEmpty '[] = '[] + MapEmpty (t : env) = TNil : MapEmpty env + +data OccEnv a env env' where + OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top! + OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env') + +instance Semigroup a => Semigroup (Some (OccEnv a env)) where + Some OccEnd <> e = e + e <> Some OccEnd = e + Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2) + +instance Semigroup a => Monoid (Some (OccEnv a env)) where + mempty = Some OccEnd + +instance Occurrence a => Occurrence (Some (OccEnv a env)) where + Some OccEnd <||> e = e + e <||> Some OccEnd = e + Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2) + + scaleMany (Some OccEnd) = Some OccEnd + scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s) + +onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env) +onehotOccEnv IZ v s = Some (OccPush OccEnd v s) +onehotOccEnv (IS i) v s + | Some env' <- onehotOccEnv i v s + = Some (OccPush env' mempty SsNone) + +occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t') +occEnvPop (OccPush e _ s) = (e, s) +occEnvPop OccEnd = (OccEnd, SsNone) + +occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r +occEnvPop' (OccPush e _ s) k = k e s +occEnvPop' OccEnd k = k OccEnd SsNone + +occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env) +occEnvPopSome (Some (OccPush e _ _)) = Some e +occEnvPopSome (Some OccEnd) = Some OccEnd + +occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t)) +occEnvPrj OccEnd _ = mempty +occEnvPrj (OccPush _ o s) IZ = (o, Some s) +occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i + +occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env')) +occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ) +occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i')) +occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ) +occEnvPrjS (OccPush e _ _) (IS i) + | Some (Pair s' i') <- occEnvPrjS e i + = Some (Pair s' (IS i')) + +projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small +projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of + _ | Just Refl <- testEquality topsbig topssmall -> ex + + (SsFull, SsFull) -> ex + (SsNone, SsNone) -> ex + (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller" + (_, SsNone) -> + case typeOf ex of + STNil -> ex + _ -> use ex $ ENil ext + + (SsPair s1 s2, SsPair s1' s2') -> + eunPair ex $ \_ e1 e2 -> + EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2) + (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex + (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex + + (SsEither s1 s2, SsEither s1' s2') + | STEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in ecase ex + (EInl ext (typeOf e2) e1) + (EInr ext (typeOf e1) e2) + (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex + (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex + + (SsLEither s1 s2, SsLEither s1' s2') + | STLEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in elcase ex + (ELNil ext (typeOf e1) (typeOf e2)) + (ELInl ext (typeOf e2) e1) + (ELInr ext (typeOf e1) e2) + (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex + (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex + + (SsMaybe s1, SsMaybe s1') + | STMaybe t1 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + in emaybe ex + (ENothing ext (typeOf e1)) + (EJust ext e1) + (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex + (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex + + (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex + (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex + (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex + + (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum" + (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex + (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex + + +-- | A boolean for each entry in the environment, with the ability to uniformly +-- mask the top part above a certain index. +data EnvMask env where + EMRest :: Bool -> EnvMask env + EMPush :: EnvMask env -> Bool -> EnvMask (t : env) + +envMaskPrj :: EnvMask env -> Idx env t -> Bool +envMaskPrj (EMRest b) _ = b +envMaskPrj (_ `EMPush` b) IZ = b +envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i + +occCount :: Idx env a -> Expr x env t -> Occ +occCount idx ex + | Some env <- occCountAll ex + = fst (occEnvPrj env idx) + +occCountAll :: Expr x env t -> Some (OccEnv Occ env) +occCountAll ex = occCountX SsFull ex $ \env _ -> Some env + +pruneExpr :: SList f env -> Expr x env t -> Ex env t +pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) + where + fullOccEnv :: SList f env -> OccEnv () env env + fullOccEnv SNil = OccEnd + fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull + +-- In one traversal, count occurrences of variables and determine what parts of +-- expressions are actually used. These two results are computed independently: +-- even if (almost) nothing of a particular term is actually used, variable +-- references in that term still count as usual. +-- +-- In @occCountX s t k@: +-- * s: how much of the result of this term is required +-- * t: the term to analyse +-- * k: is passed the actual environment usage of this expression, including +-- occurrence counts. The callback reconstructs a new expression in an +-- updated "response" environment. The response must be at least as large as +-- the computed usages. +occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t + -> (forall env'. OccEnv Occ env env' + -- response OccEnv must be at least as large as the OccEnv returned above + -> (forall env''. OccEnv () env env'' -> Ex env'' t') + -> r) + -> r +occCountX initialS topexpr k = case topexpr of + EVar _ t i -> + withSome (onehotOccEnv i (Occ One One) s) $ \env -> + k env $ \env' -> + withSome (occEnvPrjS env' i) $ \(Pair s' i') -> + projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i') + ELet _ rhs body -> + occCountX s body $ \envB mkbody -> + occEnvPop' envB $ \envB' s1 -> + occCountX s1 rhs $ \envR mkrhs -> + withSome (Some envB' <> Some envR) $ \env -> + k env $ \env' -> + ELet ext (mkrhs env') (mkbody (OccPush env' () s1)) + EPair _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsPair' s1 s2 -> + occCountX s1 a $ \env1 mka -> + occCountX s2 b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPair ext (mka env') (mkb env') + EFst _ e -> + occCountX (SsPair s SsNone) e $ \env1 mke -> + k env1 $ \env' -> + EFst ext (mke env') + ESnd _ e -> + occCountX (SsPair SsNone s) e $ \env1 mke -> + k env1 $ \env' -> + ESnd ext (mke env') + ENil _ -> + case s of + SsFull -> k OccEnd (\_ -> ENil ext) + SsNone -> k OccEnd (\_ -> ENil ext) + EInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + EInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + EInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + EInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + ECase _ e a b -> + occCountX s a $ \env1' mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env1' $ \env1 s1 -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) + ENothing _ t -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t)) + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EJust _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsMaybe s' -> + occCountX s' e $ \env1 mke -> + k env1 $ \env' -> + EJust ext (mke env') + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EMaybe _ a b e -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsMaybe s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') + ELNil _ t1 t2 -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2)) + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + ELInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + ELInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELCase _ e a b c -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occCountX s c $ \env3' mkc -> + occEnvPop' env2' $ \env2 s1 -> + occEnvPop' env3' $ \env3 s2 -> + occCountX (SsLEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env -> + k env $ \env' -> + ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) + + EConstArr _ n t x -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) + SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) + + EBuild _ n a b -> + case s of + SsNone -> + occCountX SsFull a $ \env1 mka -> + occCountX SsNone b $ \env2'' mkb -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> + k env $ \env' -> + use (EBuild ext n (mka env') $ + use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ + ENil ext) $ + ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX s' b $ \env2'' mkb -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> + k env $ \env' -> + EBuild ext n (mka env') $ + elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) + + EMap _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ + ENil ext + SsArr' s' -> + occCountX s' a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + EMap ext (mka (OccPush env' () s1)) (mkb env') + + EFold1Inner _ commut a b c -> + occCountX SsFull a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1' -> + let s1 = case s1' of + SsNone -> Some SsNone + SsPair' s1'a s1'b -> Some s1'a <> Some s1'b + s0 = case s of + SsNone -> Some SsNone + SsArr' s' -> Some s' in + withSome (s1 <> s0) $ \sElt -> + occCountX sElt b $ \env2 mkb -> + occCountX (SsArr sElt) c $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsArr sElt) s $ + EFold1Inner ext commut + (projectSmallerSubstruc SsFull sElt $ + mka (OccPush env' () (SsPair sElt sElt))) + (mkb env') (mkc env') + + ESum1Inner _ e -> handleReduction (ESum1Inner ext) e + + EUnit _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' s' -> + occCountX s' e $ \env mke -> + k env $ \env' -> + EUnit ext (mke env') + + EReplicate1Inner _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX (SsArr s') b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReplicate1Inner ext (mka env') (mkb env') + + EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e + EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e + + EReshape _ n esh e -> + case s of + SsNone -> + occCountX SsNone esh $ \env1 mkesh -> + occCountX SsNone e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkesh env') $ use (mke env') $ ENil ext + SsArr' s' -> + occCountX SsFull esh $ \env1 mkesh -> + occCountX (SsArr s') e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReshape ext n (mkesh env') (mke env') + + EZip _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ mka env' + SsArr' (SsPair' SsNone s2) -> + occCountX SsNone a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ + emap (EPair ext (ENil ext) (evar IZ)) (mkb env') + SsArr' (SsPair' s1 SsNone) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ + emap (EPair ext (evar IZ) (ENil ext)) (mka env') + SsArr' (SsPair' s1 s2) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EZip ext (mka env') (mkb env') + + EFold1InnerD1 _ cm e1 e2 e3 -> + case s of + -- If nothing is necessary, we can execute a fold and then proceed to ignore it + SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex + -- If we don't need the stores, still a fold suffices + SsPair' sP SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) + -- If for whatever reason the additional stores themselves are + -- unnecessary but the shape of the array is, then oblige + SsPair' sP (SsArr' SsNone) -> + let STArr sn _ = typeOf e3 + foldex = + elet (mapExt (\_ -> ext) e3) $ + EPair ext + (EShape ext (evar IZ)) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) + (mapExt (\_ -> ext) (weakenExpr WSink e2)) + (evar IZ)) + in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> + k env1 $ \env' -> + eunPair (mkfoldex env') $ \_ eshape earr -> + EPair ext earr (EBuild ext sn eshape (ENil ext)) + -- If at least some of the additional stores are required, we need to keep this a mapAccum + SsPair' _ (SsArr' sB) -> + -- TODO: propagate usage of primals + occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> + occEnvPop' env1_1' $ \env1' _ -> + occCountX SsFull e2 $ \env2 mkb -> + occCountX SsFull e3 $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ + EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) + (mkb env') (mkc env') + + EFold1InnerD2 _ cm ef ebog ed -> + -- TODO: propagate usage of duals + occCountX SsFull ef $ \env1_2' mkef -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' sB -> + occCountX (SsArr sB) ebog $ \env2 mkebog -> + occCountX SsFull ed $ \env3 mked -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ + EFold1InnerD2 ext cm + (mkef (OccPush (OccPush env' () sB) () SsFull)) + (mkebog env') (mked env') + + EConst _ t x -> + k OccEnd $ \_ -> + case s of + SsNone -> ENil ext + SsFull -> EConst ext t x + + EIdx0 _ e -> + occCountX (SsArr s) e $ \env1 mke -> + k env1 $ \env' -> + EIdx0 ext (mke env') + + EIdx1 _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX (SsArr s') a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx1 ext (mka env') (mkb env') + + EIdx _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + _ -> + occCountX (SsArr s) a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx ext (mka env') (mkb env') + + EShape _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX (SsArr SsNone) e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EShape ext (mke env') + + EOp _ op e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX SsFull e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EOp ext op (mke env') + + ECustom _ t1 t2 t3 e1 e2 e3 a b + | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 -> + error "Accumulators not allowed in input/output/tape of an ECustom" + | otherwise -> + case s of + SsNone -> + -- Allowed to ignore e1/e2/e3 here because no accumulators are + -- communicated, and hence no relevant effects exist + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + s' -> -- Let's be pessimistic for safety + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s' $ + ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env') + + ERecompute _ e -> + occCountX s e $ \env1 mke -> + k env1 $ \env' -> + ERecompute ext (mke env') + + EWith _ t a b -> + case s of + SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with + occCountX SsNone b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + withSome (case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone) $ \s1' -> + occCountX s1' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $ + ENil ext + SsPair sB sA -> + occCountX sB b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + let s1' = case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone in + withSome (Some sA <> s1') $ \sA' -> + occCountX sA' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $ + EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA'))) + SsFull -> occCountX (SsPair SsFull SsFull) topexpr k + + EAccum _ t p a sp b e -> + -- TODO: do better! + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + occCountX SsFull e $ \env3 mke -> + withSome (Some env1 <> Some env2) $ \env12 -> + withSome (Some env12 <> Some env3) $ \env -> + k env $ \env' -> + case s of {SsFull -> id; SsNone -> id} $ + EAccum ext t p (mka env') sp (mkb env') (mke env') + + EZero _ t e -> + occCountX (subZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EZero ext (applySubstrucM s t) (mke env') + where + subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2) + subZeroInfo SsFull = SsFull + subZeroInfo SsNone = SsNone + subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2) + subZeroInfo SsEither{} = error "Either is not a monoid" + subZeroInfo SsLEither{} = SsNone + subZeroInfo SsMaybe{} = SsNone + subZeroInfo (SsArr s') = SsArr (subZeroInfo s') + subZeroInfo SsAccum{} = error "Accum is not a monoid" + + EDeepZero _ t e -> + occCountX (subDeepZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EDeepZero ext (applySubstrucM s t) (mke env') + where + subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2) + subDeepZeroInfo SsFull = SsFull + subDeepZeroInfo SsNone = SsNone + subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo SsEither{} = error "Either is not a monoid" + subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s') + subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s') + subDeepZeroInfo SsAccum{} = error "Accum is not a monoid" + + EPlus _ t a b -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPlus ext (applySubstrucM s t) (mka env') (mkb env') + + EOneHot _ t p a b -> + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> -- TODO: do better + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env') + + EError _ t msg -> + k OccEnd $ \_ -> EError ext (applySubstruc s t) msg + where + s = simplifySubstruc (typeOf topexpr) initialS + + handleReduction :: t ~ TArr n (TScal t2) + => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) + -> Expr x env (TArr (S n) (TScal t2)) + -> r + handleReduction reduce e + | STArr (SS n) _ <- typeOf e = + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) e $ \env mke -> + k env $ \env' -> + elet (mke env') $ + EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) + SsArr' SsFull -> + occCountX (SsArr SsFull) e $ \env mke -> + k env $ \env' -> + reduce (mke env') + + +deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r +deleteUnused SNil (Some OccEnd) k = k SETop +deleteUnused (_ `SCons` env) (Some OccEnd) k = + deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub) +deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k = + deleteUnused env (Some occenv) $ \sub -> + case count of Zero -> k (SENo sub) + _ -> k (SEYesR sub) + +unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t +unsafeWeakenWithSubenv = \sub -> + subst (\x t i -> case sinkViaSubenv i sub of + Just i' -> EVar x t i' + Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") + where + sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) + sinkViaSubenv IZ (SEYesR _) = Just IZ + sinkViaSubenv IZ (SENo _) = Nothing + sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub + sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs new file mode 100644 index 0000000..8e6b745 --- /dev/null +++ b/src/CHAD/AST/Env.hs @@ -0,0 +1,95 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.Env where + +import Data.Type.Equality + +import CHAD.AST.Sparse +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types + + +-- | @env'@ is a subset of @env@: each element of @env@ is either included in +-- @env'@ ('SEYes') or not included in @env'@ ('SENo'). +data Subenv' s env env' where + SETop :: Subenv' s '[] '[] + SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env') + SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env' +deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env') + +type Subenv = Subenv' (:~:) +type SubenvS = Subenv' Sparse + +pattern SEYesR :: forall tenv tenv'. () + => forall t env env'. (tenv ~ t : env, tenv' ~ t : env') + => Subenv env env' -> Subenv tenv tenv' +pattern SEYesR s = SEYes Refl s + +{-# COMPLETE SETop, SEYesR, SENo #-} + +subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env' +subList SNil SETop = SNil +subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub) +subList (SCons _ xs) (SENo sub) = subList xs sub + +subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env +subenvAll SNil = SETop +subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env) + +subenvNone :: SList f env -> Subenv' s env '[] +subenvNone SNil = SETop +subenvNone (SCons _ env) = SENo (subenvNone env) + +subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t'] +subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) +subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) +subenvOnehot SNil i _ = case i of {} + +subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 +subenvCompose SETop SETop = SETop +subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) +subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) +subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) + +subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1') +subenvConcat sub1 SETop = sub1 +subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2) +subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) + +-- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2 +-- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r +-- subenvSplit SNil sub k = k SETop sub +-- subenvSplit (SCons _ list) (SENo sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SENo sub1) sub2 +-- subenvSplit (SCons _ list) (SEYes s sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SEYes s sub1) sub2 + +sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0 +sinkWithSubenv SETop = WId +sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub +sinkWithSubenv (SENo sub) = sinkWithSubenv sub + +wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env +wUndoSubenv SETop = WId +wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub + +subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env' +subenvMap _ SNil SETop = SETop +subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub) +subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub) + +subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env') +subenvD2E SETop = SETop +subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub) +subenvD2E (SENo sub) = SENo (subenvD2E sub) diff --git a/src/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs index fb5e138..9ddcb35 100644 --- a/src/AST/Pretty.hs +++ b/src/CHAD/AST/Pretty.hs @@ -1,32 +1,31 @@ -{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeOperators #-} -module AST.Pretty (pprintExpr, ppExpr, ppSTy, PrettyX(..)) where +module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where import Control.Monad (ap) import Data.List (intersperse, intercalate) import Data.Functor.Const -import qualified Data.Functor.Product as Product +import Data.Functor.Product qualified as Product import Data.String (fromString) import Prettyprinter import Prettyprinter.Render.String -import qualified Data.Text.Lazy as TL -import qualified Prettyprinter.Render.Terminal as PT +import Data.Text.Lazy qualified as TL +import Prettyprinter.Render.Terminal qualified as PT import System.Console.ANSI (hSupportsANSI) import System.IO (stdout) import System.IO.Unsafe (unsafePerformIO) -import AST -import AST.Count -import CHAD.Types -import Data +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types class PrettyX x where @@ -70,6 +69,7 @@ genNameIfUsedIn' prefix ty idx ex _ -> return "_" | otherwise = genName' prefix +-- TODO: let this return a type-tagged thing so that name environments are more typed than Const genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t @@ -145,12 +145,45 @@ ppExpr' d val expr = case expr of EMaybe _ a b e -> do let STMaybe t = typeOf e - a' <- ppExpr' 11 val a + e' <- ppExpr' 0 val e + a' <- ppExpr' 0 val a name <- genNameIfUsedIn t IZ b b' <- ppExpr' 0 (Const name `SCons` val) b + return $ ppParen (d > 0) $ + align $ + group (flatAlt + (annotate AKey (ppString "case") <> ppX expr <+> e' + <> hardline <> annotate AKey (ppString "of")) + (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of"))) + <> hardline + <> indent 2 + (ppString "Nothing" <+> ppString "->" <+> a' + <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b') + + ELNil _ _ _ -> return (ppString "LNil") + + ELInl _ _ e -> do e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ - ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e'] + return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e' + + ELInr _ _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e' + + ELCase _ e a b c -> do + e' <- ppExpr' 0 val e + let STLEither t1 t2 = typeOf e + a' <- ppExpr' 11 val a + name1 <- genNameIfUsedIn t1 IZ b + b' <- ppExpr' 0 (Const name1 `SCons` val) b + name2 <- genNameIfUsedIn t2 IZ c + c' <- ppExpr' 0 (Const name2 `SCons` val) c + return $ ppParen (d > 0) $ + hang 2 $ + annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of") + <> hardline <> ppString "LNil" <+> ppString "->" <+> a' + <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b' + <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c' EConstArr _ _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -168,16 +201,22 @@ ppExpr' d val expr = case expr of <> hardline <> e') (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) + EMap _ a b -> do + let STArr _ t1 = typeOf b + name <- genNameIfUsedIn t1 IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + return $ ppParen (d > 0) $ + ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] + EFold1Inner _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf a) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c - let opname = case cm of Commut -> "fold1i(C)" - Noncommut -> "fold1i" + let opname = "fold1i" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] ESum1Inner _ e -> do e' <- ppExpr' 11 val e @@ -200,6 +239,38 @@ ppExpr' d val expr = case expr of e' <- ppExpr' 11 val e return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' + EReshape _ n esh e -> do + esh' <- ppExpr' 11 val esh + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] + + EZip _ e1 e2 -> do + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] + + EFold1InnerD1 _ cm a b c -> do + name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + c' <- ppExpr' 11 val c + let opname = "fold1iD1" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + namef1 <- genNameIfUsedIn tB (IS IZ) ef + namef2 <- genNameIfUsedIn t2 IZ ef + ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef + ebog' <- ppExpr' 11 val ebog + ed' <- ppExpr' 11 val ed + let opname = "fold1iD2" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) + [ppLam [ppString namef1, ppString namef2] ef', ebog', ed'] + EConst _ ty v | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr @@ -255,6 +326,10 @@ ppExpr' d val expr = case expr of ,e1' ,e2'] + ERecompute _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e'] + EWith _ t e1 e2 -> do e1' <- ppExpr' 11 val e1 name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2 @@ -267,27 +342,35 @@ ppExpr' d val expr = case expr of <> hardline <> e2') (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - EAccum _ _ prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do e1' <- ppExpr' 11 val e1 e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr) [ppString (ppAcPrj prj), e1', e2', e3'] + ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) + [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3'] - EZero _ t -> return $ ppParen (d > 0) $ - annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSTy' 11 t + EZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' - EPlus _ _ a b -> do + EDeepZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + + EPlus _ t a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "plus") <> ppX expr) [a', b'] + ppApp (annotate AMonoid (ppString "plus") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t) [a', b'] - EOneHot _ _ prj a b -> do + EOneHot _ t prj a b -> do a' <- ppExpr' 11 val a b' <- ppExpr' 11 val b return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "onehot") <> ppX expr) [ppString (ppAcPrj prj), a', b'] + ppApp (annotate AMonoid (ppString "onehot") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), a', b'] EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) @@ -320,14 +403,28 @@ ppLam :: [ADoc] -> ADoc -> ADoc ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) <> softline <> body <> ppString ")") -ppAcPrj :: SAcPrj p a b -> String -ppAcPrj SAPHere = "@" -ppAcPrj (SAPFst prj) = "(" ++ ppAcPrj prj ++ ",)" -ppAcPrj (SAPSnd prj) = "(," ++ ppAcPrj prj ++ ")" -ppAcPrj (SAPLeft prj) = "(" ++ ppAcPrj prj ++ "|)" -ppAcPrj (SAPRight prj) = "(|" ++ ppAcPrj prj ++ ")" -ppAcPrj (SAPJust prj) = "J" ++ ppAcPrj prj -ppAcPrj (SAPArrIdx prj n) = "[" ++ ppAcPrj prj ++ "]" ++ intSubscript (fromSNat n) +ppAcPrj :: SMTy a -> SAcPrj p a b -> String +ppAcPrj _ SAPHere = "." +ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)" +ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)" +ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj +ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) + +ppSparse :: SMTy a -> Sparse a b -> String +ppSparse t sp | Just Refl <- isDense t sp = "D" +ppSparse _ SpAbsent = "A" +ppSparse t (SpSparse s) = "S" ++ ppSparse t s +ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s +ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s +ppSparse (SMTScal _) SpScal = "." + +ppCommut :: Commutative -> String +ppCommut Commut = "(C)" +ppCommut Noncommut = "" ppX :: PrettyX x => Expr x env t -> ADoc ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) @@ -361,6 +458,7 @@ ppSTy' :: Int -> STy t -> Doc q ppSTy' _ STNil = ppString "1" ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b +ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t ppSTy' d (STArr n t) = ppParen (d > 10) $ ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t @@ -370,7 +468,23 @@ ppSTy' _ (STScal sty) = ppString $ case sty of STF32 -> "f32" STF64 -> "f64" STBool -> "bool" -ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSTy' 11 t +ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t + +ppSMTy :: Int -> SMTy t -> String +ppSMTy d ty = render $ ppSMTy' d ty + +ppSMTy' :: Int -> SMTy t -> Doc q +ppSMTy' _ SMTNil = ppString "1" +ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b +ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b +ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t +ppSMTy' d (SMTArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t +ppSMTy' _ (SMTScal sty) = ppString $ case sty of + STI32 -> "i32" + STI64 -> "i64" + STF32 -> "f32" + STF64 -> "f64" ppString :: String -> Doc x ppString = fromString @@ -405,4 +519,5 @@ render = else renderString) . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 } where + {-# NOINLINE stdoutTTY #-} stdoutTTY = unsafePerformIO $ hSupportsANSI stdout diff --git a/src/CHAD/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs new file mode 100644 index 0000000..85f2882 --- /dev/null +++ b/src/CHAD/AST/Sparse.hs @@ -0,0 +1,296 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} +module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where + +import Data.Type.Equality + +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data (SBool(..)) + + +sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' +sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext +sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 +sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh +sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = + eunPair e1 $ \w1 e1a e1b -> + eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> + EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) + (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) +sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = + elet e2 $ + elcase (weakenExpr WSink e1) + (evar IZ) + (elcase (evar (IS IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) + (elcase (evar (IS IZ)) + (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") + (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = + elet e2 $ + emaybe (weakenExpr WSink e1) + (evar IZ) + (emaybe (evar (IS IZ)) + (EJust ext (evar IZ)) + (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 +sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 + + +cheapZero :: SMTy t -> Maybe (forall env. Ex env t) +cheapZero SMTNil = Just (ENil ext) +cheapZero (SMTPair t1 t2) + | Just e1 <- cheapZero t1 + , Just e2 <- cheapZero t2 + = Just (EPair ext e1 e2) + | otherwise + = Nothing +cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) +cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) +cheapZero SMTArr{} = Nothing +cheapZero (SMTScal t) = case t of + STI32 -> Just (EConst ext t 0) + STI64 -> Just (EConst ext t 0) + STF32 -> Just (EConst ext t 0.0) + STF64 -> Just (EConst ext t 0.0) + + +data Injection sp a b where + -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that + -- 'sparsePlusS' can provide injections even if the caller doesn't require + -- them. This simplifies the sparsePlusS code. + Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b + Noinj :: Injection False a b + +withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' +withInj (Inj f) k = Inj (k f) +withInj Noinj _ = Noinj + +withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 + -> ((forall e. Ex e a1 -> Ex e b1) + -> (forall e. Ex e a2 -> Ex e b2) + -> (forall e'. Ex e' a' -> Ex e' b')) + -> Injection sp a' b' +withInj2 (Inj f) (Inj g) k = Inj (k f g) +withInj2 Noinj _ _ = Noinj +withInj2 _ Noinj _ = Noinj + +-- | This function produces quadratically-sized code in the presence of nested +-- dynamic sparsity. TODO can this be improved? +sparsePlusS + :: SBool inj1 -> SBool inj2 + -> SMTy t -> Sparse t t1 -> Sparse t t2 + -> (forall t3. Sparse t t3 + -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent) + -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent) + -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) + -> r) + -> r +-- nil override (but don't destroy effects!) +sparsePlusS _ _ SMTNil _ _ k = + k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) + +-- simplifications +sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = + sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> + k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) +sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = + sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> + k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) + +sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = + let ta = applySparse sp1 (fromSMTy t) in + sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = + let tb = applySparse sp2 (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = + let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in + sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + minj2 + (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = + let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = + let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in + sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = + let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) +sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k +sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k + +-- TODO: sparse of Just is just Maybe + +-- dense plus +sparsePlusS _ _ t sp1 sp2 k + | Just Refl <- isDense t sp1 + , Just Refl <- isDense t sp2 + = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) + +-- handle absents +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) +sparsePlusS ST _ t SpAbsent sp2 k + | Just zero2 <- cheapZero (applySparse sp2 t) = + k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) + | otherwise = + k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) + +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) +sparsePlusS _ ST t sp1 SpAbsent k + | Just zero1 <- cheapZero (applySparse sp1 t) = + k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) + | otherwise = + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) + +-- double sparse yields sparse +sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- single sparse can yield non-sparse if the other argument is always present +sparsePlusS SF _ t (SpSparse sp1) sp2 k = + sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> + k sp3 Noinj (Inj inj2) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (inj2 (evar IZ)) + (plus (evar IZ) (evar (IS IZ)))) +sparsePlusS ST _ t (SpSparse sp1) sp2 k + | Just zero2 <- cheapZero (applySparse sp2 t) = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k sp3 + (Inj $ \a -> emaybe a (inj2 zero2) (inj1 (evar IZ))) + (Inj inj2) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (inj2 (evar IZ)) + (plus (evar IZ) (evar (IS IZ)))) + | otherwise = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> EJust ext (inj2 b)) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (EJust ext (inj2 (evar IZ))) + (EJust ext (plus (evar IZ) (evar (IS IZ))))) +sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = + sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> + k sp3 inj2 inj1 (flip plus) + +-- products +sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = + sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> + sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> + k (SpPair sp3a sp3b) + (withInj2 minj13a minj13b $ \inj13a inj13b -> + \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) + (withInj2 minj23a minj23b $ \inj23a inj23b -> + \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) + (\x1 x2 -> + eunPair x1 $ \w1 x1a x1b -> + eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> + EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) + +-- coproducts +sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = + sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> + sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> + let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) + inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) + inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) + in + k (SpLEither sp3a sp3b) + (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) + (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) + (\x1 x2 -> + elet x2 $ + elcase (weakenExpr WSink x1) + (elcase (evar IZ) + nil + (inl (inj23a (evar IZ))) + (inr (inj23b (evar IZ)))) + (elcase (evar (IS IZ)) + (inl (inj13a (evar IZ))) + (inl (plusa (evar (IS IZ)) (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) + (elcase (evar (IS IZ)) + (inr (inj13b (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") + (inr (plusb (evar (IS IZ)) (evar IZ))))) + +-- maybe +sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpMaybe sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- dense array cotangents simply recurse +sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> + k (SpArr sp3) + (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) + (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) + (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + +-- scalars +sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs new file mode 100644 index 0000000..8f41ba4 --- /dev/null +++ b/src/CHAD/AST/Sparse/Types.hs @@ -0,0 +1,107 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.Sparse.Types where + +import Data.Kind (Type, Constraint) +import Data.Type.Equality + +import CHAD.AST.Types + + +data Sparse t t' where + SpSparse :: Sparse t t' -> Sparse t (TMaybe t') + SpAbsent :: Sparse t TNil + + SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') + SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') + SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') + SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpScal :: Sparse (TScal t) (TScal t) +deriving instance Show (Sparse t t') + +class ApplySparse f where + applySparse :: Sparse t t' -> f t -> f t' + +instance ApplySparse STy where + applySparse (SpSparse s) t = STMaybe (applySparse s t) + applySparse SpAbsent _ = STNil + applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) + applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse SpScal t = t + +instance ApplySparse SMTy where + applySparse (SpSparse s) t = SMTMaybe (applySparse s t) + applySparse SpAbsent _ = SMTNil + applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) + applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse SpScal t = t + + +class IsSubType s where + type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint + subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' + subtTrans :: s a b -> s b c -> s a c + subtFull :: IsSubTypeSubject s f => f t -> s t t + +instance IsSubType (:~:) where + type IsSubTypeSubject (:~:) f = () + subtApply = gcastWith + subtTrans = trans + subtFull _ = Refl + +instance IsSubType Sparse where + type IsSubTypeSubject Sparse f = f ~ SMTy + subtApply = applySparse + + subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) + subtTrans _ SpAbsent = SpAbsent + subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) + subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) + subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) + subtTrans SpScal SpScal = SpScal + + subtFull = spDense + +spDense :: SMTy t -> Sparse t t +spDense SMTNil = SpAbsent +spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) +spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) +spDense (SMTMaybe t) = SpMaybe (spDense t) +spDense (SMTArr _ t) = SpArr (spDense t) +spDense (SMTScal _) = SpScal + +isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') +isDense SMTNil SpAbsent = Just Refl +isDense _ SpSparse{} = Nothing +isDense _ SpAbsent = Nothing +isDense (SMTPair t1 t2) (SpPair s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTLEither t1 t2) (SpLEither s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTMaybe t) (SpMaybe s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTArr _ t) (SpArr s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTScal _) SpScal = Just Refl + +isAbsent :: Sparse t t' -> Bool +isAbsent (SpSparse s) = isAbsent s +isAbsent SpAbsent = True +isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpMaybe s) = isAbsent s +isAbsent (SpArr s) = isAbsent s +isAbsent SpScal = False diff --git a/src/AST/SplitLets.hs b/src/CHAD/AST/SplitLets.hs index dcba1ad..34267e4 100644 --- a/src/AST/SplitLets.hs +++ b/src/CHAD/AST/SplitLets.hs @@ -7,13 +7,13 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -module AST.SplitLets (splitLets) where +module CHAD.AST.SplitLets (splitLets) where import Data.Type.Equality -import AST -import AST.Bindings -import Lemmas +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.Lemmas splitLets :: Ex env t -> Ex env t @@ -22,16 +22,26 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t splitLets' = \sub -> \case EVar _ t i -> sub t i WId - ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) + ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) ECase x e a b -> let STEither t1 t2 = typeOf e in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) EMaybe x a b e -> let STMaybe t1 = typeOf e in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e) + ELCase x e a b c -> + let STLEither t1 t2 = typeOf e + in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c - in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD1 x cm a b c -> + let STArr _ t1 = typeOf c + in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD2 x cm a b c -> + let STArr _ tB = typeOf b + STArr _ t2 = typeOf c + in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c) EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) EFst x e -> EFst x (splitLets' sub e) @@ -41,13 +51,19 @@ splitLets' = \sub -> \case EInr x t e -> EInr x t (splitLets' sub e) ENothing x t -> ENothing x t EJust x e -> EJust x (splitLets' sub e) + ELNil x t1 t2 -> ELNil x t1 t2 + ELInl x t e -> ELInl x t (splitLets' sub e) + ELInr x t e -> ELInr x t (splitLets' sub e) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) + EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) ESum1Inner x e -> ESum1Inner x (splitLets' sub e) EUnit x e -> EUnit x (splitLets' sub e) EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) + EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) + EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (splitLets' sub e) EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) @@ -55,9 +71,11 @@ splitLets' = \sub -> \case EShape x e -> EShape x (splitLets' sub e) EOp x op e -> EOp x op (splitLets' sub e) ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) + ERecompute x e -> ERecompute x (splitLets' sub e) EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3) - EZero x t -> EZero x t + EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) + EZero x t ezi -> EZero x t (splitLets' sub ezi) + EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) EError x t s -> EError x t s @@ -81,15 +99,42 @@ splitLets' = \sub -> \case -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t split2 sub tbind1 tbind2 body = let (ptrs1', bs1') = split @env' tbind1 - bs1 = fst (weakenBindings weakenExpr WSink bs1') + bs1 = fst (weakenBindingsE WSink bs1') (ptrs2, bs2) = split @(bind1 : env') tbind2 in letBinds bs1 $ - letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) body + -- TODO: abstract this to splitN lol wtf + _split4 :: forall bind1 bind2 bind3 bind4 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t + _split4 sub tbind1 tbind2 tbind3 tbind4 body = + let (ptrs1, bs1') = split @env' tbind1 + (ptrs2, bs2') = split @(bind1 : env') tbind2 + (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 + (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 + bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') + bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') + bs3 = fst (weakenBindingsE WSink bs3') + b1 = bindingsBinds bs1 + b2 = bindingsBinds bs2 + b3 = bindingsBinds bs3 + b4 = bindingsBinds bs4 + in letBinds bs1 $ + letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ + splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) + _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) + _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) + _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) + t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) + body + type family Split t where Split (TPair a b) = SplitRec (TPair a b) Split _ = '[] @@ -117,6 +162,7 @@ split typ = case typ of STPair{} -> splitRec (EVar ext typ IZ) typ STNil -> other STEither{} -> other + STLEither{} -> other STMaybe{} -> other STArr{} -> other STScal{} -> other @@ -127,18 +173,19 @@ split typ = case typ of splitRec :: forall env t. Ex env t -> STy t -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t)) -splitRec rhs = \case +splitRec rhs typ = case typ of STNil -> (PNil, BTop) STPair (a :: STy a) (b :: STy b) | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env -> let (p1, bs1) = splitRec (EFst ext rhs) a (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) - t@STEither{} -> other t - t@STMaybe{} -> other t - t@STArr{} -> other t - t@STScal{} -> other t - t@STAccum{} -> other t + STEither{} -> other + STLEither{} -> other + STMaybe{} -> other + STArr{} -> other + STScal{} -> other + STAccum{} -> other where - other :: STy t -> (Pointers (t : env) t, Bindings Ex env '[t]) - other t = (Point t IZ, BPush BTop (t, rhs)) + other :: (Pointers (t : env) t, Bindings Ex env '[t]) + other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/AST/Types.hs b/src/CHAD/AST/Types.hs index b20fc2d..f0feb55 100644 --- a/src/AST/Types.hs +++ b/src/CHAD/AST/Types.hs @@ -5,10 +5,10 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeData #-} -module AST.Types where +module CHAD.AST.Types where import Data.Int (Int32, Int64) import Data.GADT.Compare @@ -16,13 +16,14 @@ import Data.GADT.Show import Data.Kind (Type) import Data.Type.Equality -import Data +import CHAD.Data type data Ty = TNil | TPair Ty Ty | TEither Ty Ty + | TLEither Ty Ty | TMaybe Ty | TArr Nat Ty -- ^ rank, element type | TScal ScalTy @@ -30,15 +31,18 @@ type data Ty type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool +-- | Scalar types happen to be bundled in 'SScalTy' as this is sometimes +-- convenient, but such scalar types are not special in any way. type STy :: Ty -> Type data STy t where STNil :: STy TNil STPair :: STy a -> STy b -> STy (TPair a b) STEither :: STy a -> STy b -> STy (TEither a b) + STLEither :: STy a -> STy b -> STy (TLEither a b) STMaybe :: STy a -> STy (TMaybe a) STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) - STAccum :: STy t -> STy (TAccum t) + STAccum :: SMTy t -> STy (TAccum t) deriving instance Show (STy t) instance GCompare STy where @@ -49,6 +53,8 @@ instance GCompare STy where STPair{} _ -> GLT ; _ STPair{} -> GGT (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') STEither{} _ -> GLT ; _ STEither{} -> GGT + (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STLEither{} _ -> GLT ; _ STLEither{} -> GGT (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a') STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') @@ -62,6 +68,45 @@ instance TestEquality STy where testEquality = geq instance GEq STy where geq = defaultGeq instance GShow STy where gshowsPrec = defaultGshowsPrec +-- | Monoid types +type SMTy :: Ty -> Type +data SMTy t where + SMTNil :: SMTy TNil + SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b) + SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b) + SMTMaybe :: SMTy a -> SMTy (TMaybe a) + SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) + SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) +deriving instance Show (SMTy t) + +instance GCompare SMTy where + gcompare = \cases + SMTNil SMTNil -> GEQ + SMTNil _ -> GLT ; _ SMTNil -> GGT + (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT + (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT + (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a') + SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT + (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') + SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT + (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t') + -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + +instance TestEquality SMTy where testEquality = geq +instance GEq SMTy where geq = defaultGeq +instance GShow SMTy where gshowsPrec = defaultGshowsPrec + +fromSMTy :: SMTy t -> STy t +fromSMTy = \case + SMTNil -> STNil + SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2) + SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2) + SMTMaybe t -> STMaybe (fromSMTy t) + SMTArr n t -> STArr n (fromSMTy t) + SMTScal sty -> STScal sty + data SScalTy t where STI32 :: SScalTy TI32 STI64 :: SScalTy TI64 @@ -128,14 +173,25 @@ type family ScalIsIntegral t where ScalIsIntegral TBool = False -- | Returns true for arrays /and/ accumulators. -hasArrays :: STy t' -> Bool -hasArrays STNil = False -hasArrays (STPair a b) = hasArrays a || hasArrays b -hasArrays (STEither a b) = hasArrays a || hasArrays b -hasArrays (STMaybe t) = hasArrays t -hasArrays STArr{} = True -hasArrays STScal{} = False -hasArrays STAccum{} = True +typeHasArrays :: STy t' -> Bool +typeHasArrays STNil = False +typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STMaybe t) = typeHasArrays t +typeHasArrays STArr{} = True +typeHasArrays STScal{} = False +typeHasArrays STAccum{} = True + +typeHasAccums :: STy t' -> Bool +typeHasAccums STNil = False +typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STMaybe t) = typeHasAccums t +typeHasAccums STArr{} = False +typeHasAccums STScal{} = False +typeHasAccums STAccum{} = True type family Tup env where Tup '[] = TNil diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs new file mode 100644 index 0000000..d3cad25 --- /dev/null +++ b/src/CHAD/AST/UnMonoid.hs @@ -0,0 +1,252 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where + +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data + + +-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by +-- expanding them into their concrete implementations. Also ensure that +-- 'EAccum' has a dense sparsity. +unMonoid :: Ex env t -> Ex env t +unMonoid = \case + EZero _ t e -> zero t e + EDeepZero _ t e -> deepZero t e + EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) + EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) + + EVar _ t i -> EVar ext t i + ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) + EPair _ a b -> EPair ext (unMonoid a) (unMonoid b) + EFst _ e -> EFst ext (unMonoid e) + ESnd _ e -> ESnd ext (unMonoid e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext t (unMonoid e) + EInr _ t e -> EInr ext t (unMonoid e) + ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b) + ENothing _ t -> ENothing ext t + EJust _ e -> EJust ext (unMonoid e) + EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) + ELNil _ t1 t2 -> ELNil ext t1 t2 + ELInl _ t e -> ELInl ext t (unMonoid e) + ELInr _ t e -> ELInr ext t (unMonoid e) + ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) + EConstArr _ n t x -> EConstArr ext n t x + EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) + EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) + EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) + ESum1Inner _ e -> ESum1Inner ext (unMonoid e) + EUnit _ e -> EUnit ext (unMonoid e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) + EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) + EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) + EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) + EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) + EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) + EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) + EConst _ t x -> EConst ext t x + EIdx0 _ e -> EIdx0 ext (unMonoid e) + EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) + EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b) + EShape _ e -> EShape ext (unMonoid e) + EOp _ op e -> EOp ext op (unMonoid e) + ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) + ERecompute _ e -> ERecompute ext (unMonoid e) + EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) + EAccum _ t p eidx sp eval eacc -> + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) + EError _ t s -> EError ext t s + +zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t +-- don't destroy the effects! +zero SMTNil e = ELet ext e $ ENil ext +zero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) +zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) +zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e +zero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + +deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t +deepZero SMTNil e = elet e $ ENil ext +deepZero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +deepZero (SMTLEither t1 t2) e = + elcase e + (ELNil ext (fromSMTy t1) (fromSMTy t2)) + (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) + (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) +deepZero (SMTMaybe t) e = + emaybe e + (ENothing ext (fromSMTy t)) + (EJust ext (deepZero t (evar IZ))) +deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e +deepZero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + +plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t +-- don't destroy the effects! +plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext +plus (SMTPair t1 t2) a b = + eunPair a $ \w1 a1 a2 -> + eunPair (weakenExpr w1 b) $ \w2 b1 b2 -> + EPair ext (plus t1 (weakenExpr w2 a1) b1) + (plus t2 (weakenExpr w2 a2) b2) +plus (SMTLEither t1 t2) a b = + let t = STLEither (fromSMTy t1) (fromSMTy t2) + in ELet ext a $ + ELet ext (weakenExpr WSink b) $ + ELCase ext (EVar ext t (IS IZ)) + (EVar ext t IZ) + (ELCase ext (EVar ext t (IS IZ)) + (EVar ext t (IS (IS IZ))) + (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ))) + (EError ext t "plus l+r")) + (ELCase ext (EVar ext t (IS IZ)) + (EVar ext t (IS (IS IZ))) + (EError ext t "plus r+l") + (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ)))) +plus (SMTMaybe t) a b = + ELet ext b $ + EMaybe ext + (EVar ext (STMaybe (fromSMTy t)) IZ) + (EJust ext + (EMaybe ext + (EVar ext (fromSMTy t) IZ) + (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) + (EVar ext (STMaybe (fromSMTy t)) (IS IZ)))) + (weakenExpr WSink a) +plus (SMTArr _ t) a b = + ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) + a b +plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) + +onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t +onehot typ topprj idx arg = case (typ, topprj) of + (_, SAPHere) -> + ELet ext arg $ + EVar ext (fromSMTy typ) IZ + + (SMTPair t1 t2, SAPFst prj) -> + ELet ext idx $ + let tidx = typeOf idx in + ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ + let toh = fromSMTy t1 in + EPair ext (EVar ext toh IZ) + (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) + + (SMTPair t1 t2, SAPSnd prj) -> + ELet ext idx $ + let tidx = typeOf idx in + ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ + let toh = fromSMTy t2 in + EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) + (EVar ext toh IZ) + + (SMTLEither t1 t2, SAPLeft prj) -> + ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) + (SMTLEither t1 t2, SAPRight prj) -> + ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) + + (SMTMaybe t1, SAPJust prj) -> + EJust ext (onehot t1 prj idx arg) + + (SMTArr n t1, SAPArrIdx prj) -> + let tidx = tTup (sreplicate n tIx) + in ELet ext idx $ + EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ + eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) + (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) + (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ + zero t1 (EVar ext (tZeroInfo t1) IZ)) + +accumulateSparse + :: SMTy t -> Sparse t t' -> Ex env t' + -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) + -> Ex env TNil +accumulateSparse topty topsp arg accum = case (topty, topsp) of + (_, s) | Just Refl <- isDense topty s -> + accum WId SAPHere (ENil ext) arg + (SMTScal _, SpScal) -> + accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh + (_, SpSparse s) -> + emaybe arg + (ENil ext) + (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) + (_, SpAbsent) -> + ENil ext + (SMTPair t1 t2, SpPair s1 s2) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ + accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) + (SMTLEither t1 t2, SpLEither s1 s2) -> + elcase arg + (ENil ext) + (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) + (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) + (SMTMaybe t, SpMaybe s) -> + emaybe arg + (ENil ext) + (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) + (SMTArr n t, SpArr s) -> + let tn = tTup (sreplicate n tIx) in + elet arg $ + elet (EBuild ext n (EShape ext (evar IZ)) $ + accumulateSparse t s + (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) + (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ + ENil ext + +acPrjCompose + :: SAIDense dense + -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) + -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r +acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 +acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPFst p') idx' +acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPSnd p') idx' +acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) +acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') +acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPLeft p') idx' +acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPRight p') idx' +acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPJust p') idx' +acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') +acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') diff --git a/src/AST/Weaken.hs b/src/CHAD/AST/Weaken.hs index d882e28..ac0d152 100644 --- a/src/AST/Weaken.hs +++ b/src/CHAD/AST/Weaken.hs @@ -15,14 +15,15 @@ -- The reason why this is a separate module with "little" in it: {-# LANGUAGE AllowAmbiguousTypes #-} -module AST.Weaken (module AST.Weaken, Append) where +module CHAD.AST.Weaken (module CHAD.AST.Weaken, Append) where import Data.Bifunctor (first) import Data.Functor.Const +import Data.GADT.Compare import Data.Kind (Type) -import Data -import Lemmas +import CHAD.Data +import CHAD.Lemmas type Idx :: [k] -> k -> Type @@ -31,6 +32,11 @@ data Idx env t where IS :: Idx env t -> Idx (a : env) t deriving instance Show (Idx env t) +instance GEq (Idx env) where + geq IZ IZ = Just Refl + geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl + geq _ _ = Nothing + splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) splitIdx SNil i = Right i splitIdx (SCons _ _) IZ = Left IZ @@ -123,7 +129,7 @@ wCopies bs w = let bs' = slistMap (\_ -> Const ()) bs in WStack bs' bs' WId w -wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env +wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env wRaiseAbove SNil _ = WClosed wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) diff --git a/src/AST/Weaken/Auto.hs b/src/CHAD/AST/Weaken/Auto.hs index 6752c24..229940b 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/CHAD/AST/Weaken/Auto.hs @@ -1,35 +1,34 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS_GHC -Wno-partial-type-signatures #-} -module AST.Weaken.Auto ( +module CHAD.AST.Weaken.Auto ( autoWeak, (&.), auto, auto1, Layout(..), ) where import Data.Functor.Const +import Data.Kind (Constraint) import GHC.OverloadedLabels import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) -import AST.Weaken -import Data -import Lemmas +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Lemmas type family Lookup name list where @@ -39,18 +38,21 @@ type family Lookup name list where -- | The @withPre@ type parameter indicates whether there can be 'LPreW' --- occurrences within this layout. -data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where - LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments (Lookup name segments) +-- occurrences within this layout. 'names' is the list of names that this +-- layout /produces/. That is: for LPreW, it contains the target name. The +-- 'names' list of a source layout must be a subset of the names list of the +-- target layout (which cannot contain LPreW); this is checked with SubLayout. +data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where + LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments) -- | Pre-weaken with a weakening LPreW :: forall name1 name2 segments. SegmentName name1 -> SegmentName name2 -> Lookup name1 segments :> Lookup name2 segments - -> Layout True segments (Lookup name1 segments) - (:++:) :: Layout withPre segments env1 -> Layout withPre segments env2 -> Layout withPre segments (Append env1 env2) + -> Layout True segments '[name2] (Lookup name1 segments) + (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2) infixr :++: -instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout withPre segments seg) where +instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where fromLabel = LSeg (symbolSing @name) newtype SegmentName name = SegmentName (SSymbol name) @@ -60,11 +62,23 @@ instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') wh fromLabel = SegmentName symbolSing +type family SubLayout names1 names2 where + SubLayout '[] _ = () :: Constraint + SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2 +type family SubLayout' n ok names1 names2 where + SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.") + SubLayout' _ True names1 names2 = SubLayout names1 names2 +type family Contains n names where + Contains _ '[] = False + Contains n (n : _) = True + Contains n (_ : names) = Contains n names + + data SSegments (segments :: [(Symbol, [t])]) where SSegNil :: SSegments '[] SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) -instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where +instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil auto :: KnownListSpine list => SList (Const ()) list @@ -74,7 +88,7 @@ auto1 :: SList (Const ()) '[t] auto1 = Const () `SCons` SNil infixr &. -(&.) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2) +(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2) (&.) = ssegmentsAppend where ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) @@ -118,12 +132,12 @@ linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) -lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env -lineariseLayout (LSeg name :: Layout _ _ seg) +lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env +lineariseLayout (LSeg name :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinApp name LinEnd lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 -lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ seg) +lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinAppPreW name1 name2 w LinEnd @@ -151,8 +165,7 @@ pullDown segs name@SSymbol linlayout kNotFound k = k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name) .> wCopies (segmentLookup segs n') w) -sortLinLayouts :: forall segments env1 env2. - SSegments segments +sortLinLayouts :: SSegments segments -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2 sortLinLayouts _ LinEnd LinEnd = WId sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) @@ -169,8 +182,8 @@ sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail sortLinLayouts _ LinEnd LinApp{} = WClosed sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" -autoWeak :: forall segments env1 env2. - SSegments segments -> Layout True segments env1 -> Layout False segments env2 -> env1 :> env2 +autoWeak :: SubLayout names1 names2 + => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2 autoWeak segs ly1 ly2 = preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 -> sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs deleted file mode 100644 index b61b5ff..0000000 --- a/src/CHAD/Accum.hs +++ /dev/null @@ -1,27 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -module CHAD.Accum where - -import AST -import CHAD.Types -import Data - - - -makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators SNil e = e -makeAccumulators (t `SCons` envpro) e = - makeAccumulators envpro $ - EWith ext t (EZero ext t) e - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = - ELet ext (uninvertTup list (STPair tcore t) e) $ - let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding - in EPair ext - (EFst ext (EFst ext (EVar ext recT IZ))) - (EPair ext - (ESnd ext (EVar ext recT IZ)) - (ESnd ext (EFst ext (EVar ext recT IZ)))) - diff --git a/src/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs index f34bfbc..212cc7d 100644 --- a/src/Analysis/Identity.hs +++ b/src/CHAD/Analysis/Identity.hs @@ -3,7 +3,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -module Analysis.Identity ( +module CHAD.Analysis.Identity ( identityAnalysis, identityAnalysis', ValId(..), @@ -13,11 +13,11 @@ module Analysis.Identity ( import Data.Foldable (toList) import Data.List (intercalate) -import AST -import AST.Pretty (PrettyX(..)) -import CHAD.Types (d1, d2) -import Data -import Util.IdGen +import CHAD.AST +import CHAD.AST.Pretty (PrettyX(..)) +import CHAD.Data +import CHAD.Drev.Types (d1, d2) +import CHAD.Util.IdGen -- | Every array, scalar and accumulator has an ID. Trivial values such as @@ -28,6 +28,7 @@ data ValId t where VIPair :: ValId a -> ValId b -> ValId (TPair a b) VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case + VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b) VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a) VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value VIArr :: Int -> Vec n Int -> ValId (TArr n t) @@ -45,6 +46,13 @@ instance PrettyX ValId where VIMaybe Nothing -> "N" VIMaybe (Just a) -> 'J' : prettyX a VIMaybe' a -> 'M' : prettyX a + VILEither (VIMaybe Nothing) -> "lN" + VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")" + VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")" + VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")" + VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")" + VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")" + VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))" VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]" VIScal i -> show i VIAccum i -> 'C' : show i @@ -147,6 +155,42 @@ idana env expr = case expr of res <- unify v1 v2 pure (res, EMaybe res e1' e2' e3') + ELNil _ t1 t2 -> do + let v = VILEither (VIMaybe Nothing) + pure (v, ELNil v t1 t2) + + ELInl _ t2 e1 -> do + (v1, e1') <- idana env e1 + let v = VILEither (VIMaybe (Just (VIEither (Left v1)))) + pure (v, ELInl v t2 e1') + + ELInr _ t1 e2 -> do + (v2, e2') <- idana env e2 + let v = VILEither (VIMaybe (Just (VIEither (Right v2)))) + pure (v, ELInr v t1 e2') + + ELCase _ e1 e2 e3 e4 -> do + let STLEither t1 t2 = typeOf e1 + (v1L, e1') <- idana env e1 + let VILEither v1 = v1L + let go mv1'l mv1'r f = do + v1'l <- maybe (genIds t1) pure mv1'l + v1'r <- maybe (genIds t2) pure mv1'r + (v2, e2') <- idana env e2 + (v3, e3') <- idana (v1'l `SCons` env) e3 + (v4, e4') <- idana (v1'r `SCons` env) e4 + res <- f v2 v3 v4 + pure (res, ELCase res e1' e2' e3' e4') + case v1 of + VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2) + VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3) + VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4) + VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4) + VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3) + VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4) + VIMaybe' (VIEither' v1'l v1'r) -> + go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4) + EConstArr _ dim t arr -> do x1 <- VIArr <$> genId <*> vecReplicateA dim genId pure (x1, EConstArr x1 dim t arr) @@ -158,11 +202,19 @@ idana env expr = case expr of res <- VIArr <$> genId <*> shidsToVec dim shids pure (res, EBuild res dim e1' e2') + EMap _ e1 e2 -> do + let STArr _ t = typeOf e2 + x1 <- genIds t + (_, e1') <- idana (x1 `SCons` env) e1 + (v2, e2') <- idana env e2 + let VIArr _ sh = v2 + res <- VIArr <$> genId <*> pure sh + pure (res, EMap res e1' e2') + EFold1Inner _ cm e1 e2 e3 -> do let t1 = typeOf e1 - x1 <- genIds t1 - x2 <- genIds t1 - (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 (_, e2') <- idana env e2 (v3, e3') <- idana env e3 let VIArr _ (_ :< sh) = v3 @@ -200,6 +252,41 @@ idana env expr = case expr of res <- VIArr <$> genId <*> pure sh pure (res, EMinimum1Inner res e1') + EReshape _ dim e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- VIArr <$> genId <*> shidsToVec dim v1 + pure (res, EReshape res dim e1' e2') + + EZip _ e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + let VIArr _ sh = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, EZip res e1' e2') + + EFold1InnerD1 _ cm e1 e2 e3 -> do + let t1 = typeOf e2 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 + (_, e2') <- idana env e2 + (v3, e3') <- idana env e3 + let VIArr _ sh'@(_ :< sh) = v3 + res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh') + pure (res, EFold1InnerD1 res cm e1' e2' e3') + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + xf1 <- genIds t2 + xf2 <- genIds tB + (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef + (v2, e2') <- idana env ebog + (_, e3') <- idana env ed + let VIArr _ sh@(_ :< sh') = v2 + res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh) + pure (res, EFold1InnerD2 res cm e1' e2' e3') + EConst _ t val -> do res <- VIScal <$> genId pure (res, EConst res t val) @@ -250,6 +337,10 @@ idana env expr = case expr of res <- genIds t4 pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5') + ERecompute _ e -> do + (v, e') <- idana env e + pure (v, ERecompute v e') + EWith _ t e1 e2 -> do let t1 = typeOf e1 (_, e1') <- idana env e1 @@ -259,26 +350,36 @@ idana env expr = case expr of let res = VIPair v2 x2 pure (res, EWith res t e1' e2') - EAccum _ t prj e1 e2 e3 -> do + EAccum _ t prj e1 sp e2 e3 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 (_, e3') <- idana env e3 - pure (VINil, EAccum VINil t prj e1' e2' e3') + pure (VINil, EAccum VINil t prj e1' sp e2' e3') - EZero _ t -> do - res <- genIds (d2 t) - pure (res, EZero res t) + EZero _ t e1 -> do + -- Approximate the result of EZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EZero res t e1') + + EDeepZero _ t e1 -> do + -- Approximate the result of EDeepZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EDeepZero res t e1') EPlus _ t e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- genIds (d2 t) + res <- genIds (fromSMTy t) pure (res, EPlus res t e1' e2') EOneHot _ t i e1 e2 -> do (_, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- genIds (d2 t) + res <- genIds (fromSMTy t) pure (res, EOneHot res t i e1' e2') EError _ t s -> do @@ -307,6 +408,7 @@ unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VILEither a) (VILEither b) = VILEither <$> unify a b unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j @@ -319,6 +421,7 @@ genIds :: STy t -> IdGen (ValId t) genIds STNil = pure VINil genIds (STPair a b) = VIPair <$> genIds a <*> genIds b genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b +genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b) genIds (STMaybe t) = VIMaybe' <$> genIds t genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId genIds STScal{} = VIScal <$> genId diff --git a/src/Array.hs b/src/CHAD/Array.hs index 707dce2..caf63ef 100644 --- a/src/Array.hs +++ b/src/CHAD/Array.hs @@ -2,19 +2,20 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TupleSections #-} -module Array where +module CHAD.Array where import Control.DeepSeq import Control.Monad.Trans.State.Strict import Data.Foldable (traverse_) import Data.Vector (Vector) -import qualified Data.Vector as V +import Data.Vector qualified as V import GHC.Generics (Generic) -import Data +import CHAD.Data data Shape n where @@ -91,6 +92,11 @@ arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l) arrayToList :: Array n t -> [t] arrayToList (Array _ v) = V.toList v +arrayReshape :: Shape n -> Array m t -> Array n t +arrayReshape sh (Array sh' v) + | shapeSize sh == shapeSize sh' = Array sh v + | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")" + arrayUnit :: t -> Array Z t arrayUnit x = Array ShNil (V.singleton x) diff --git a/src/Compile.hs b/src/CHAD/Compile.hs index e2d004a..44a335c 100644 --- a/src/Compile.hs +++ b/src/CHAD/Compile.hs @@ -2,13 +2,14 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} -module Compile (compile) where +module CHAD.Compile (compile, compileStderr) where import Control.Applicative (empty) import Control.Monad (forM_, when, replicateM) @@ -20,36 +21,37 @@ import Data.Bifunctor (first) import Data.Char (ord) import Data.Foldable (toList) import Data.Functor.Const -import qualified Data.Functor.Product as Product +import Data.Functor.Product qualified as Product import Data.Functor.Product (Product) import Data.IORef import Data.List (foldl1', intersperse, intercalate) -import qualified Data.Map.Strict as Map +import Data.Map.Strict qualified as Map import Data.Maybe (fromMaybe) -import qualified Data.Set as Set +import Data.Set qualified as Set import Data.Set (Set) import Data.Some -import qualified Data.Vector as V +import Data.Vector qualified as V import Foreign import GHC.Exts (int2Word#, addr2Int#) import GHC.Num (integerFromWord#) import GHC.Ptr (Ptr(..)) +import GHC.Stack (HasCallStack) import Numeric (showHex) import System.IO (hPutStrLn, stderr) import System.IO.Error (mkIOError, userErrorType) import System.IO.Unsafe (unsafePerformIO) import Prelude hiding ((^)) -import qualified Prelude +import Prelude qualified -import Array -import AST -import AST.Pretty (ppSTy, ppExpr) -import qualified CHAD.Types as CHAD -import Compile.Exec -import Data -import Interpreter.Rep -import qualified Util.IdGen as IdGen +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty (ppSTy, ppExpr) +import CHAD.AST.Sparse.Types (isDense) +import CHAD.Compile.Exec +import CHAD.Data +import CHAD.Interpreter.Rep +import CHAD.Util.IdGen qualified as IdGen -- In shape and index arrays, the innermost dimension is on the right (last index). @@ -70,28 +72,30 @@ debugAllocs :: Bool; debugAllocs = toEnum 0 -- | Emit extra C code that checks stuff emitChecks :: Bool; emitChecks = toEnum 0 +-- | Returns compiled function plus compilation output (warnings) compile :: SList STy env -> Ex env t - -> IO (SList Value env -> IO (Rep t)) + -> IO (SList Value env -> IO (Rep t), String) compile = \env expr -> do codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i)) let (source, offsets) = compileToString codeID env expr when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" - lib <- buildKernel source ["kernel"] + (lib, compileOutput) <- buildKernel source "kernel" let result_type = typeOf expr result_size = sizeofSTy result_type - return $ \val -> do - allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do - let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) - serialiseArguments args ptr $ do - callKernelFun "kernel" lib ptr - ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) - when (ok /= 1) $ - ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) - deserialise result_type ptr (koResultOffset offsets) + let function val = do + allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do + let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) + serialiseArguments args ptr $ do + callKernelFun lib ptr + ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) + when (ok /= 1) $ + ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) + deserialise result_type ptr (koResultOffset offsets) + return (function, compileOutput) where serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = @@ -99,6 +103,15 @@ compile = \env expr -> do serialiseArguments args ptr k serialiseArguments _ _ k = k +-- | 'compile', but writes any produced C compiler output to stderr. +compileStderr :: SList STy env -> Ex env t + -> IO (SList Value env -> IO (Rep t)) +compileStderr env expr = do + (fun, output) <- compile env expr + when (not (null output)) $ + hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" + return fun + data StructDecl = StructDecl String -- ^ name @@ -126,7 +139,7 @@ data CExpr | CECall String [CExpr] -- ^ function(arg1, ..., argn) | CEBinop CExpr String CExpr -- ^ expr + expr | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr - | CECast String CExpr -- ^ (<type)<expr> + | CECast String CExpr -- ^ (<type>)<expr> deriving (Show) printStructDecl :: StructDecl -> ShowS @@ -215,75 +228,88 @@ repSTy (STScal st) = case st of STBool -> "uint8_t" repSTy t = genStructName t -genStructName :: STy t -> String -genStructName = \t -> "ty_" ++ gen t where - -- all tags start with a letter, so the array mangling is unambiguous. - gen :: STy t -> String - gen STNil = "n" - gen (STPair a b) = 'P' : gen a ++ gen b - gen (STEither a b) = 'E' : gen a ++ gen b - gen (STMaybe t) = 'M' : gen t - gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t - gen (STScal st) = case st of - STI32 -> "i" - STI64 -> "j" - STF32 -> "f" - STF64 -> "d" - STBool -> "b" - gen (STAccum t) = 'C' : gen t +genStructName, genArrBufStructName :: STy t -> String +(genStructName, genArrBufStructName) = + (\t -> "ty_" ++ gen t + ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension + t -> error $ "genArrBufStructName: not an array type: " ++ show t) + where + -- all tags start with a letter, so the array mangling is unambiguous. + gen :: STy t -> String + gen STNil = "n" + gen (STPair a b) = 'P' : gen a ++ gen b + gen (STEither a b) = 'E' : gen a ++ gen b + gen (STLEither a b) = 'L' : gen a ++ gen b + gen (STMaybe t) = 'M' : gen t + gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t + gen (STScal st) = case st of + STI32 -> "i" + STI64 -> "j" + STF32 -> "f" + STF64 -> "d" + STBool -> "b" + gen (STAccum t) = 'C' : gen (fromSMTy t) + +-- The subtrees contain structs used in the bodies of the structs in this node. +data StructTree = TreeNode [StructDecl] [StructTree] + deriving (Show) -- | This function generates the actual struct declarations for each of the -- types in our language. It thus implicitly "documents" the layout of the -- types in the C translation. -genStruct :: String -> STy t -> [StructDecl] -genStruct name topty = case topty of +-- +-- For accumulation it is important that for struct representations of monoid +-- types, the all-zero-bytes value corresponds to the zero value of that type. +buildStructTree :: STy t -> StructTree +buildStructTree topty = case topty of STNil -> - [StructDecl name "" com] + TreeNode [StructDecl name "" com] [] STPair a b -> - [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + [buildStructTree a, buildStructTree b] STEither a b -> -- 0 -> l, 1 -> r - [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] + STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] STMaybe t -> -- 0 -> nothing, 1 -> just - [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + [buildStructTree t] STArr n t -> -- The buffer is trailed by a VLA for the actual array data. - [StructDecl (name ++ "_buf") ("size_t sh[" ++ show (fromSNat n) ++ "]; size_t refc; " ++ repSTy t ++ " xs[];") "" - ,StructDecl name (name ++ "_buf *buf;") com] + -- TODO: no buffer if n = 0 + TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") "" + ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com] + [buildStructTree t] STScal _ -> - [] + TreeNode [] [] STAccum t -> - [StructDecl (name ++ "_buf") (repSTy (CHAD.d2 t) ++ " ac;") "" - ,StructDecl name (name ++ "_buf *buf;") com] + TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" + ,StructDecl name (name ++ "_buf *buf;") com] + [buildStructTree (fromSMTy t)] where + name = genStructName topty com = ppSTy 0 topty -- State: already-generated (skippable) struct names -- Writer: the structs in declaration order -genStructs :: STy t -> WriterT (Bag StructDecl) (State (Set String)) () -genStructs ty = do - let name = genStructName ty - seen <- lift $ gets (name `Set.member`) - - if seen - then pure () - else do - -- already mark this struct as generated now, so we don't generate it - -- twice (unnecessary because no recursive types, but y'know) - lift $ modify (Set.insert name) - - () <- case ty of - STNil -> pure () - STPair a b -> genStructs a >> genStructs b - STEither a b -> genStructs a >> genStructs b - STMaybe t -> genStructs t - STArr _ t -> genStructs t - STScal _ -> pure () - STAccum t -> genStructs (CHAD.d2 t) - - tell (BList (genStruct name ty)) +genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) () +genStructTreeW (TreeNode these deps) = do + seen <- lift get + case filter ((`Set.notMember` seen) . nameOf) these of + [] -> pure () + structs -> do + lift $ modify (Set.fromList (map nameOf structs) <>) + mapM_ genStructTreeW deps + tell (BList structs) + where + nameOf (StructDecl name _ _) = name genAllStructs :: Foldable t => t (Some STy) -> [StructDecl] -genAllStructs tys = toList $ evalState (execWriterT (mapM_ (\(Some t) -> genStructs t) tys)) mempty +genAllStructs tys = + let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys + in toList (evalState (execWriterT m) mempty) data CompState = CompState { csStructs :: Set (Some STy) @@ -332,6 +358,12 @@ emitStruct ty = CompM $ do modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } return (genStructName ty) +-- | Also returns the name of the array buffer struct +emitArrStruct :: STy t -> CompM (String, String) +emitArrStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } + return (genStructName ty, genArrBufStructName ty) + emitTLD :: String -> CompM () emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } @@ -419,10 +451,10 @@ compileToString codeID env expr = else id ,showString $ " const bool success = typed_kernel(" ++ "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++ - concat (map (\((arg, typ), off) -> - ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" - ++ " /* " ++ arg ++ " */") - (zip arg_pairs arg_offsets)) ++ + concat (zipWith (\(arg, typ) off -> + ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" + ++ " /* " ++ arg ++ " */") + arg_pairs arg_offsets) ++ "\n );\n" ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n" ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n" @@ -450,11 +482,20 @@ serialise topty topval ptr off k = serialise a x ptr off $ serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k (STEither a _, Left x) -> do - pokeByteOff ptr off (0 :: Word8) -- alignment of (a + b) is alignment of (union {a b}) + pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b) serialise a x ptr (off + alignmentSTy topty) k (STEither _ b, Right y) -> do pokeByteOff ptr off (1 :: Word8) serialise b y ptr (off + alignmentSTy topty) k + (STLEither _ _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STLEither a _, Just (Left x)) -> do + pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STLEither _ b, Just (Right y)) -> do + pokeByteOff ptr off (2 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k (STMaybe _, Nothing) -> do pokeByteOff ptr off (0 :: Word8) k @@ -463,19 +504,18 @@ serialise topty topval ptr off k = serialise t x ptr (off + alignmentSTy t) k (STArr n t, Array sh vec) -> do let eltsz = sizeofSTy t - allocaBytes (fromSNat n * 8 + 8 + shapeSize sh * eltsz) $ \bufptr -> do + allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do when debugRefc $ hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr pokeByteOff ptr off bufptr + pokeShape ptr (off + 8) n sh - pokeShape bufptr 0 n sh - pokeByteOff @Word64 bufptr (8 * fromSNat n) (2 ^ 63) + pokeByteOff @Word64 bufptr 0 (2 ^ 63) - let off1 = fromSNat n * 8 + 8 - loop i + let loop i | i == shapeSize sh = k | otherwise = - serialise t (vec V.! i) bufptr (off1 + i * eltsz) $ + serialise t (vec V.! i) bufptr (8 + i * eltsz) $ loop (i+1) loop 0 (STScal sty, x) -> case sty of @@ -498,9 +538,16 @@ deserialise topty ptr off = return (x, y) STEither a b -> do tag <- peekByteOff @Word8 ptr off - if tag == 0 -- alignment of (a + b) is alignment of (union {a b}) + if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) then Left <$> deserialise a ptr (off + alignmentSTy topty) else Right <$> deserialise b ptr (off + alignmentSTy topty) + STLEither a b -> do + tag <- peekByteOff @Word8 ptr off + case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) + 0 -> return Nothing + 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) + 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) + _ -> error "Invalid tag value" STMaybe t -> do tag <- peekByteOff @Word8 ptr off if tag == 0 @@ -508,13 +555,12 @@ deserialise topty ptr off = else Just <$> deserialise t ptr (off + alignmentSTy t) STArr n t -> do bufptr <- peekByteOff @(Ptr ()) ptr off - sh <- peekShape bufptr 0 n - refc <- peekByteOff @Word64 bufptr (8 * fromSNat n) + sh <- peekShape ptr (off + 8) n + refc <- peekByteOff @Word64 bufptr 0 when debugRefc $ hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc - let off1 = 8 * fromSNat n + 8 - eltsz = sizeofSTy t - arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (off1 + i * eltsz)) + let eltsz = sizeofSTy t + arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz)) when (refc < 2 ^ 62) $ free bufptr return arr STScal sty -> case sty of @@ -545,17 +591,21 @@ metricsSTy (STEither a b) = let (a1, s1) = metricsSTy a (a2, s2) = metricsSTy b in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned +metricsSTy (STLEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned metricsSTy (STMaybe t) = let (a, s) = metricsSTy t in (a, a + s) -- the union after the tag byte is aligned -metricsSTy (STArr _ _) = (8, 8) +metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n) metricsSTy (STScal sty) = case sty of STI32 -> (4, 4) STI64 -> (8, 8) STF32 -> (4, 4) STF64 -> (8, 8) STBool -> (1, 1) -- compiled to uint8_t -metricsSTy (STAccum t) = metricsSTy t +metricsSTy (STAccum t) = metricsSTy (fromSMTy t) pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () pokeShape ptr off = go . fromSNat @@ -571,7 +621,7 @@ peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n) peekShape ptr off = \case SZ -> return ShNil SS n -> ShCons <$> peekShape ptr off n - <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8)) + <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + fromSNat n * 8)) compile' :: SList (Const String) env -> Ex env t -> CompM CExpr compile' env = \case @@ -685,16 +735,51 @@ compile' env = \case <> pure (SAsg retvar e3)))) return (CELit retvar) + ELNil _ t1 t2 -> do + name <- emitStruct (STLEither t1 t2) + return $ CEStruct name [("tag", CELit "0")] + + ELInl _ t e -> do + name <- emitStruct (STLEither (typeOf e) t) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("l", e1)] + + ELInr _ t e -> do + name <- emitStruct (STLEither t (typeOf e)) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "2"), ("r", e1)] + + ELCase _ e a b c -> do + let STLEither t1 t2 = typeOf e + e1 <- compile' env e + var <- genName + (e2, stmts2) <- scope $ compile' env a + (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b + (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c + ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l") + ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r") + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (stmts2 <> pure (SAsg retvar e2)) + (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1")) + (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3)) + (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4)))))) + return (CELit retvar) + EConstArr _ n t (Array sh vec) -> do - strname <- emitStruct (STArr n (STScal t)) + (strname, bufstrname) <- emitArrStruct (STArr n (STScal t)) tldname <- genName' "carraybuf" -- Give it a refcount of _half_ the size_t max, so that it can be -- incremented and decremented at will and will "never" reach anything -- where something happens - emitTLD $ "static " ++ strname ++ "_buf " ++ tldname ++ " = " ++ - "(" ++ strname ++ "_buf){.sh = {" ++ intercalate "," (map show (shapeToList sh)) ++ "}, " ++ - ".refc = (size_t)1<<63, .xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" - return (CEStruct strname [("buf", CEAddrOf (CELit tldname))]) + emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++ + "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++ + ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" + return (CEStruct strname + [("buf", CEAddrOf (CELit tldname)) + ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))]) EBuild _ n esh efun -> do shname <- compileAssign "sh" env esh @@ -709,7 +794,7 @@ compile' env = \case emit $ SBlock $ pure (SVarDecl False "size_t" linivar (CELit "0")) <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") - (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx)))) + (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx)))) | (ivar, dimidx) <- zip ivars [0::Int ..]] (pure (SVarDecl True (repSTy (typeOf esh)) idxargname (shapeTupFromLitVars n ivars)) @@ -718,6 +803,15 @@ compile' env = \case return (CELit arrname) + -- TODO: actually generate decent code here + EMap _ e1 e2 -> do + let STArr n _ = typeOf e2 + compile' env $ + elet e2 $ + EBuild ext n (EShape ext (evar IZ)) $ + elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e1 + EFold1Inner _ commut efun ex0 earr -> do let STArr (SS n) t = typeOf earr @@ -734,12 +828,11 @@ compile' env = \case -- unexpected. But it's exactly what we want, so we do it anyway. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname) - resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) - [CELit (arrname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname) lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (arrname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name @@ -748,22 +841,26 @@ compile' env = \case -- kvar <- if vecwid > 1 then genName' "k" else return "" accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" - (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit + pairstrname <- emitStruct (STPair t t) emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ pure (SVarDecl False (repSTy t) accvar (CELit x0name)) <> x0incrStmts -- we're copying x0 here - <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - -- The combination function will consume the array element - -- and the accumulator. The accumulator is replaced by - -- what comes out of the function anyway, so that's - -- fine, but we do need to increment the array element. - arreltIncrStmts - <> funStmts - <> pure (SAsg accvar funres)) + <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) + <> funStmts + <> pure (SAsg accvar funres)) <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) incrementVarAlways "foldx0" Decrement t x0name @@ -781,12 +878,11 @@ compile' env = \case -- This n is one less than the shape of the thing we're querying, like EFold1Inner. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) - [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname) lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) let vecwid = 8 :: Int ivar <- genName' "i" @@ -833,8 +929,7 @@ compile' env = \case resname <- allocArray "repl1i" Malloc "rep" (SS n) t (Just (CEBinop (CELit shszname) "*" (CELit lenname))) - ([CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] - ++ [CELit lenname]) + (compileArrShapeComponents n argname ++ [CELit lenname]) ivar <- genName' "i" jvar <- genName' "j" @@ -851,6 +946,149 @@ compile' env = \case EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e + EReshape _ dim esh earg -> do + let STArr origDim eltty = typeOf earg + strname <- emitStruct (STArr dim eltty) + + shname <- compileAssign "reshsh" env esh + arrname <- compileAssign "resharg" env earg + + when emitChecks $ do + emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname)))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++ + printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++ + printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;") + mempty + + return (CEStruct strname + [("buf", CEProj (CELit arrname) "buf") + ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + + -- TODO: actually generate decent code here + EZip _ e1 e2 -> do + let STArr n _ = typeOf e1 + compile' env $ + elet e1 $ + elet (weakenExpr WSink e2) $ + EBuild ext n (EShape ext (evar (IS IZ))) $ + EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) + (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) + + EFold1InnerD1 _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + STPair _ bty = typeOf efun + + x0name <- compileAssign "foldd1x0" env ex0 + arrname <- compileAssign "foldd1arr" env earr + + zeroRefcountCheck (typeOf earr) "fold1iD1" arrname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname) + storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar + arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" + funresvar <- genName' "res" + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + + pairstrname <- emitStruct (STPair t t) + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd1x0" Decrement t x0name + incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname + + strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty)) + return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)]) + + EFold1InnerD2 _ commut efun estores ectg -> do + let STArr n t2 = typeOf ectg + STArr _ bty = typeOf estores + + storesname <- compileAssign "foldd2stores" env estores + ctgname <- compileAssign "foldd2ctg" env ectg + + zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname) + outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname) + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "acc" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar + storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]" + ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]" + (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun + funresvar <- genName' "res" + ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit + ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit)) + <> ctgeltIncrStmts + -- we need to loop in reverse here, but we let jvar run in the + -- forward direction so that we can use SLoop. Note jvar is + -- reversed in eltidx above + <> pure (SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the accumulator + -- and the stores element. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the stores element. + storeseltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname + incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname + + strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2)) + return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)]) + EConst _ t x -> return $ CELit $ compileScal True t x EIdx0 _ e -> do @@ -876,7 +1114,7 @@ compile' env = \case when emitChecks $ forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) -> emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||" - (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".buf->sh[" ++ show i ++ "]"))))) + (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]"))))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++ arrname ++ ".buf); return false;") @@ -919,6 +1157,8 @@ compile' env = \case maybe (return ()) ($ name2) mfun2 return (CELit name) + ERecompute _ e -> compile' env e + EWith _ t e1 e2 -> do actyname <- emitStruct (STAccum t) name1 <- compileAssign "" env e1 @@ -926,179 +1166,157 @@ compile' env = \case zeroRefcountCheck (typeOf e1) "with" name1 emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")" - mcopy <- copyForWriting (CHAD.d2 t) name1 + mcopy <- copyForWriting t name1 accname <- genName' "accum" emit $ SVarDecl False actyname accname - (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (CHAD.d2 t)))])]) - emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy) + (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])]) + emit $ SAsg (accname++".buf->ac") (fromMaybe (CELit name1) mcopy) emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")." e2' <- compile' (Const accname `SCons` env) e2 resname <- genName' "acret" - emit $ SVarDecl True (repSTy (CHAD.d2 t)) resname (CELit (accname++".buf->ac")) + emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac")) emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);" - rettyname <- emitStruct (STPair (typeOf e2) (CHAD.d2 t)) + rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] - EAccum _ t prj eidx eval eacc -> do - nameidx <- compileAssign "acidx" env eidx - nameval <- compileAssign "acval" env eval - - -- Generate the variable manually because this one has to be non-const. - eacc' <- compile' env eacc - nameacc <- genName' "acac" - emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc' + EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do + let -- Add a value (s) into an existing accumulation value (d). If a sparse + -- component of d is encountered, s is copied there. + add :: SMTy a -> String -> String -> CompM () + add SMTNil _ _ = return () + add (SMTPair t1 t2) d s = do + add t1 (d++".a") (s++".a") + add t2 (d++".b") (s++".b") + add (SMTLEither t1 t2) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s + ((), stmts1) <- scope $ add t1 (d++".l") (s++".l") + ((), stmts2) <- scope $ add t2 (d++".r") (s++".r") + emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + ((if emitChecks + then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0")) + "&&" + (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag")))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++ + "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++ + "return false;") + mempty) + else mempty) + -- note: s may have tag 0 + <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) + stmts1 + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2")) + stmts2 mempty)))) + add (SMTMaybe t1) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s + ((), stmts1) <- scope $ add t1 (d++".j") (s++".j") + emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty)) + add (SMTArr n t1) d s = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ [0 .. fromSNat n - 1] $ \j -> do + emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]")) + "!=" + (CELit (d ++ ".sh[" ++ show j ++ "]"))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ + "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ + d ++ ".buf" ++ + concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + ", " ++ s ++ ".buf" ++ + concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + "); " ++ + "return false;") + mempty - let -- Expects a variable reference to a value of type @D2 a@. - setZero :: STy a -> String -> CompM () - setZero STNil _ = return () - setZero STPair{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Pair (D2 a) (D2 b)) - setZero STEither{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Either (D2 a) (D2 b)) - setZero STMaybe{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (D2 a) - setZero STArr{} v = emit $ SAsg (v++".tag") (CELit "0") -- Maybe (Arr n (D2 a)) - setZero (STScal sty) v = case sty of - STI32 -> return () -- Nil - STI64 -> return () -- Nil - STF32 -> emit $ SAsg v (CELit "0.0f") - STF64 -> emit $ SAsg v (CELit "0.0") - STBool -> return () -- Nil - setZero STAccum{} _ = error "Compile: setZero: nested accumulators unsupported" + shsizename <- genName' "acshsz" + emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s) + ivar <- genName' "i" + ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) + stmts1 + add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - initD2Pair :: STy a -> STy b -> String -> CompM () - initD2Pair a b v = do -- Maybe (Pair (D2 a) (D2 b)) - ((), stmts1) <- scope $ setZero a (v++".j.a") - ((), stmts2) <- scope $ setZero b (v++".j.b") - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmts1 <> stmts2) - mempty + let -- | Dereference an accumulation value and add a given value to that + -- position. Sparse components encountered along the way are + -- initialised before proceeding downwards. + -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there) + accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () + accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend - initD2Either :: STy a -> STy b -> String -> Either () () -> CompM () - initD2Either a b v side = do -- Maybe (Either (D2 a) (D2 b)) - ((), stmts) <- case side of - Left () -> scope $ setZero a (v++".j.l") - Right () -> scope $ setZero b (v++".j.r") - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmts) - mempty + accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend + accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend - initD2Maybe :: STy a -> String -> CompM () - initD2Maybe a v = do -- Maybe (D2 a) - ((), stmts) <- scope $ setZero a (v++".j") - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) <> stmts) - mempty + accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef ta prj' (v++".l") i addend + accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tb prj' (v++".r") i addend - -- mind: this has to traverse the D2 of these things, and it also has to - -- initialise data structures that are still sparse in the accumulator. - let accumRef :: STy a -> SAcPrj p a b -> String -> String -> CompM String - accumRef _ SAPHere v _ = pure v - accumRef (STPair ta tb) (SAPFst prj') v i = do - initD2Pair ta tb v - accumRef ta prj' (v++".j.a") i - accumRef (STPair ta tb) (SAPSnd prj') v i = do - initD2Pair ta tb v - accumRef tb prj' (v++".j.b") i - accumRef (STEither ta tb) (SAPLeft prj') v i = do - initD2Either ta tb v (Left ()) - accumRef ta prj' (v++".j.l") i - accumRef (STEither ta tb) (SAPRight prj') v i = do - initD2Either ta tb v (Right ()) - accumRef tb prj' (v++".j.r") i - accumRef (STMaybe tj) (SAPJust prj') v i = do - initD2Maybe tj v - accumRef tj prj' (v++".j") i - accumRef (STArr n t') (SAPArrIdx prj' _) v i = do - (newarrName, newarrStmts) <- scope $ allocArray "accumRef" Calloc "prjarr" n t' Nothing (indexTupleComponents n (i++".a.b")) - emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) - (pure (SAsg (v++".tag") (CELit "1")) - <> newarrStmts - <> pure (SAsg (v++".j") (CELit newarrName))) - mempty + accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tj prj' (v++".j") i addend + accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do when emitChecks $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ (zip3 [0::Int ..] - (indexTupleComponents n (i++".a.a")) - (indexTupleComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do + forM_ (zip [0::Int ..] + (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do let a .||. b = CEBinop a "||" b emit $ SIf (CEBinop ixcomp "<" (CELit "0") .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]"))) - .||. - CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".j.buf->sh[" ++ show j ++ "]")))) + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]")))) (pure $ SVerbatim $ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ - "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ - v ++ ".j.buf" ++ - concat [", " ++ v ++ ".j.buf->sh[" ++ show k ++ "]" | k <- [0 .. fromSNat n - 1]] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.b")] ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ + v ++ ".buf" ++ + concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ "); " ++ "return false;") mempty - accumRef t' prj' (v++".j.buf->xs[" ++ printCExpr 0 (toLinearIdx n (v++".j") (i++".a.a")) "]") (i++".b") + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend - -- mind: this has to add the D2 of these things, and it also has to - -- initialise data structures that are still sparse in the accumulator. - let add :: STy a -> String -> String -> CompM () - add STNil _ _ = return () - add (STPair t1 t2) d s = do - ((), stmts1) <- scope $ add t1 (d++".j.a") (s++".j.a") - ((), stmts2) <- scope $ add t2 (d++".j.b") (s++".j.b") - emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (stmts1 <> stmts2) - mempty)) - add (STEither t1 t2) d s = do - ((), stmts1) <- scope $ add t1 (d++".j.l") (s++".j.l") - ((), stmts2) <- scope $ add t2 (d++".j.r") (s++".j.r") - emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (pure (SAsg (d++".j.tag") (CELit (s++".j.tag"))) - <> pure (SIf (CEBinop (CELit (s++".j.tag")) "==" (CELit "0")) - stmts1 stmts2)) - mempty)) - add (STMaybe t1) d s = do - ((), stmts1) <- scope $ add t1 (d++".j") (s++".j") - emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (pure (SAsg (d++".tag") (CELit "1")) <> stmts1) - mempty)) - add (STArr n t1) d s = do - shsizename <- genName' "acshsz" - ivar <- genName' "i" - ((), stmts1) <- scope $ add t1 (d++".j.buf->xs["++ivar++"]") (s++".j.buf->xs["++ivar++"]") - ((), stmtsDecr) <- scope $ incrementVarAlways "accumarr" Decrement (STArr n (CHAD.d2 t1)) (s++".j") - emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - (pure (SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s))) - (pure (SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j"))) - -- TODO: emit check here for the source being either equal in shape to the destination - <> pure (SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) - stmts1) - <> stmtsDecr))) - mempty - add (STScal sty) d s = case sty of - STI32 -> return () - STI64 -> return () - STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - STBool -> return () - add (STAccum _) _ _ = error "Compile: nested accumulators unsupported" + nameidx <- compileAssign "acidx" env eidx + nameval <- compileAssign "acval" env eval + nameacc <- compileAssign "acac" env eacc emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" - dest <- accumRef t prj (nameacc++".buf->ac") nameidx - add (acPrjTy prj t) dest nameval + accumRef t prj (nameacc++".buf->ac") nameidx nameval emit $ SVerbatim $ "// compile EAccum end" + incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval + return $ CEStruct (repSTy STNil) [] + EAccum{} -> + error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)" + EError _ t s -> do let padleft len c s' = replicate (len - length s) c ++ s' escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] @@ -1111,9 +1329,10 @@ compile' env = \case name <- emitStruct t return $ CEStruct name [] - EZero{} -> error "Compile: monoid operations should have been eliminated" - EPlus{} -> error "Compile: monoid operations should have been eliminated" - EOneHot{} -> error "Compile: monoid operations should have been eliminated" + EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" EIdx1{} -> error "Compile: not implemented: EIdx1" @@ -1144,6 +1363,7 @@ data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array | ATNoop -- ^ don't do anything here | ATProj String ArrayTree -- ^ descend one field deeper | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second + | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2 | ATBoth ArrayTree ArrayTree -- ^ do both these paths smartATProj :: String -> ArrayTree -> ArrayTree @@ -1154,6 +1374,10 @@ smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree smartATCondTag ATNoop ATNoop = ATNoop smartATCondTag t t' = ATCondTag t t' +smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree +smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop +smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3 + smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree smartATBoth ATNoop t = t smartATBoth t ATNoop = t @@ -1165,6 +1389,9 @@ makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) (smartATProj "b" (makeArrayTree b)) makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a)) (smartATProj "r" (makeArrayTree b)) +makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop + (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t)) makeArrayTree (STArr n t) = ATArray (Some n) (Some t) makeArrayTree (STScal _) = ATNoop @@ -1204,6 +1431,15 @@ incrementVar' marker inc path (ATCondTag t1 t2) = do ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2 +incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do + ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 + ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 + ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3 + emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1")) + stmts2 + (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2")) + stmts3 + stmts1)) incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2 toLinearIdx :: SNat n -> String -> String -> CExpr @@ -1211,21 +1447,21 @@ toLinearIdx SZ _ _ = CELit "0" toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b") toLinearIdx (SS n) arrvar idxvar = CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a")) - "*" (CEIndex (CELit (arrvar ++ ".buf->sh")) (CELit (show (fromSNat n))))) + "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n))))) "+" (CELit (idxvar ++ ".b")) -- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr -- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) [] -- fromLinearIdx (SS n) arrvar idxvar = do -- name <- genName --- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]"))) +-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]"))) -- _ data AllocMethod = Malloc | Calloc deriving (Show) -- | The shape must have the outer dimension at the head (and the inner dimension on the right). -allocArray :: String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String +allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String allocArray marker method nameBase rank eltty mshsz shape = do when (length shape /= fromSNat rank) $ error "allocArray: shape does not match rank" @@ -1240,9 +1476,8 @@ allocArray marker method nameBase rank eltty mshsz shape = do (CEBinop shsz "*" (CELit (show (sizeofSTy eltty)))) emit $ SVarDecl True strname arrname $ CEStruct strname [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr] - Calloc -> CECall "calloc_instr" [nbytesExpr])] - forM_ (zip shape [0::Int ..]) $ \(dim, i) -> - emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim + Calloc -> CECall "calloc_instr" [nbytesExpr]) + ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))] emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") when debugRefc $ emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);" @@ -1253,14 +1488,16 @@ compileShapeQuery SZ _ = CEStruct (repSTy STNil) [] compileShapeQuery (SS n) var = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", compileShapeQuery n var) - ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))] + ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))] -- | Takes a variable name for the array, not the buffer. compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize SZ _ = CELit "1" -compileArrShapeSize n var = - foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]") - | i <- [0 .. fromSNat n - 1]] +compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var) + +-- | Takes a variable name for the array, not the buffer. +compileArrShapeComponents :: SNat n -> String -> [CExpr] +compileArrShapeComponents n var = + [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] indexTupleComponents :: SNat n -> String -> [CExpr] indexTupleComponents = \n var -> map CELit (toList (go n var)) @@ -1279,6 +1516,9 @@ shapeTupFromLitVars = \n -> go n . reverse go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" +prodExpr :: [CExpr] -> CExpr +prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") + compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr compileOpGeneral op e1 = do let unary cop = return @CompM $ CECall cop [e1] @@ -1347,12 +1587,11 @@ compileExtremum nameBase opName operator env e = do -- unexpected. But it's exactly what we want, so we do it anyway. emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) - [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname) lenname <- genName' "n" emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }" @@ -1375,47 +1614,47 @@ compileExtremum nameBase opName operator env e = do -- | If this returns Nothing, there was nothing to copy because making a simple -- value copy in C already makes it suitable to write to. -copyForWriting :: STy t -> String -> CompM (Maybe CExpr) +copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr) copyForWriting topty var = case topty of - STNil -> return Nothing + SMTNil -> return Nothing - STPair a b -> do + SMTPair a b -> do e1 <- copyForWriting a (var ++ ".a") e2 <- copyForWriting b (var ++ ".b") case (e1, e2) of (Nothing, Nothing) -> return Nothing - _ -> return $ Just $ CEStruct (repSTy topty) + _ -> return $ Just $ CEStruct toptyname [("a", fromMaybe (CELit (var++".a")) e1) ,("b", fromMaybe (CELit (var++".b")) e2)] - STEither a b -> do + SMTLEither a b -> do (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l") (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r") case (e1, e2) of (Nothing, Nothing) -> return Nothing _ -> do name <- genName - emit $ SVarDeclUninit (repSTy topty) name + emit $ SVarDeclUninit toptyname name emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) (stmts1 - <> pure (SAsg name (CEStruct (repSTy topty) + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) (stmts2 - <> pure (SAsg name (CEStruct (repSTy topty) + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)]))) return (Just (CELit name)) - STMaybe t -> do + SMTMaybe t -> do (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j") case e1 of Nothing -> return Nothing Just e1' -> do name <- genName - emit $ SVarDeclUninit (repSTy topty) name + emit $ SVarDeclUninit toptyname name emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) - (pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "0")]))) + (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")]))) (stmts1 - <> pure (SAsg name (CEStruct (repSTy topty) [("tag", CELit "1"), ("j", e1')]))) + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')]))) return (Just (CELit name)) -- If there are no nested arrays, we know that a refcount of 1 means that the @@ -1423,53 +1662,51 @@ copyForWriting topty var = case topty of -- nesting we'd have to check the refcounts of all the nested arrays _too_; -- let's not do that. Furthermore, no sub-arrays means that the whole thing -- is flat, and we can just memcpy if necessary. - STArr n t | not (hasArrays t) -> do + SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do name <- genName shszname <- genName' "shsz" - emit $ SVarDeclUninit (repSTy (STArr n t)) name + emit $ SVarDeclUninit toptyname name when debugShapes $ do let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" emit $ SVerbatim $ "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++ - concat [", " ++ var ++ ".buf->sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ ");" emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) (pure (SAsg name (CELit var))) (let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes in BList [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) - ,SAsg name (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc_instr" [totalbytes])]) - ,SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ - show shbytes ++ ");" + ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) + ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" ,SAsg (name ++ ".buf->refc") (CELit "1") ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ printCExpr 0 databytes ");"]) return (Just (CELit name)) - STArr n t -> do + SMTArr n t -> do shszname <- genName' "shsz" emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy t))) + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes name <- genName - emit $ SVarDecl False (repSTy (STArr n t)) name - (CEStruct (repSTy (STArr n t)) [("buf", CECall "malloc_instr" [totalbytes])]) - emit $ SVerbatim $ "memcpy(" ++ name ++ ".buf->sh, " ++ var ++ ".buf->sh, " ++ - show shbytes ++ ");" + emit $ SVarDecl False toptyname name + (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) + emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" emit $ SAsg (name ++ ".buf->refc") (CELit "1") -- put the arrays in variables to cut short the not-quite-var chain dstvar <- genName' "cpydst" - emit $ SVarDecl True (repSTy t ++ " *") dstvar (CELit (name ++ ".buf->xs")) + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs")) srcvar <- genName' "cpysrc" - emit $ SVarDecl True (repSTy t ++ " *") srcvar (CELit (var ++ ".buf->xs")) + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs")) ivar <- genName' "i" @@ -1484,9 +1721,10 @@ copyForWriting topty var = case topty of return (Just (CELit name)) - STScal _ -> return Nothing + SMTScal _ -> return Nothing - STAccum _ -> error "Compile: Nested accumulators not supported" + where + toptyname = repSTy (fromSMTy topty) zeroRefcountCheck :: STy t -> String -> String -> CompM () zeroRefcountCheck toptyp opname topvar = @@ -1505,6 +1743,14 @@ zeroRefcountCheck toptyp opname topvar = go (STEither a b) path = do (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 + go (STLEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ + SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) + s1 + (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) + s2 + mempty)) go (STMaybe a) path = do ss <- go a (path++".j") return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty diff --git a/src/Compile/Exec.hs b/src/CHAD/Compile/Exec.hs index d708fc0..ffe5661 100644 --- a/src/Compile/Exec.hs +++ b/src/CHAD/Compile/Exec.hs @@ -1,6 +1,5 @@ {-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TupleSections #-} -module Compile.Exec ( +module CHAD.Compile.Exec ( KernelLib, buildKernel, callKernelFun, @@ -11,8 +10,6 @@ module Compile.Exec ( import Control.Monad (when) import Data.IORef -import qualified Data.Map.Strict as Map -import Data.Map.Strict (Map) import Foreign (Ptr) import Foreign.Ptr (FunPtr) import System.Directory (removeDirectoryRecursive) @@ -30,10 +27,10 @@ debug :: Bool debug = False -- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs) -data KernelLib = KernelLib !(IORef (Map String (FunPtr (Ptr () -> IO ())))) +data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ()))) -buildKernel :: String -> [String] -> IO KernelLib -buildKernel csource funnames = do +buildKernel :: String -> String -> IO (KernelLib, String) +buildKernel csource funname = do template <- (++ "/tmp.chad.") <$> getTempDir path <- mkdtemp template @@ -44,7 +41,9 @@ buildKernel csource funnames = do ,"-o", outso, "-" ,"-Wall", "-Wextra" ,"-Wno-unused-variable", "-Wno-unused-but-set-variable" - ,"-Wno-unused-parameter", "-Wno-unused-function"] + ,"-Wno-unused-parameter", "-Wno-unused-function" + ,"-Wno-alloc-size-larger-than" -- ideally we'd keep this, but gcc reports false positives + ,"-Wno-maybe-uninitialized"] -- maximum1i goes out of range if its input is empty, yes, don't complain (ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource -- Print the source before the GCC output. @@ -52,11 +51,6 @@ buildKernel csource funnames = do ExitSuccess -> return () ExitFailure{} -> hPutStrLn stderr $ "[chad] Kernel compilation failed! Source: <<<\n" ++ lineNumbers csource ++ ">>>" - when (not (null gccStdout)) $ - hPutStrLn stderr $ "[chad] Kernel compilation: GCC stdout: <<<\n" ++ gccStdout ++ ">>>" - when (not (null gccStderr)) $ - hPutStrLn stderr $ "[chad] Kernel compilation: GCC stderr: <<<\n" ++ gccStderr ++ ">>>" - case ec of ExitSuccess -> return () ExitFailure{} -> do @@ -69,22 +63,21 @@ buildKernel csource funnames = do removeDirectoryRecursive path -- we keep a reference anyway because we have the file open now - ptrs <- Map.fromList <$> sequence [(name,) <$> dlsym dl name | name <- funnames] - ref <- newIORef ptrs + ref <- newIORef =<< dlsym dl funname _ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1)) when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)" dlclose dl) - return (KernelLib ref) + return (KernelLib ref, gccStdout ++ (if null gccStdout then "" else "\n") ++ gccStderr) foreign import ccall "dynamic" wrapKernelFun :: FunPtr (Ptr () -> IO ()) -> Ptr () -> IO () -- Ensure that keeping a reference to the returned function also keeps the 'KernelLib' alive {-# NOINLINE callKernelFun #-} -callKernelFun :: String -> KernelLib -> Ptr () -> IO () -callKernelFun key (KernelLib ref) arg = do - mp <- readIORef ref - wrapKernelFun (mp Map.! key) arg +callKernelFun :: KernelLib -> Ptr () -> IO () +callKernelFun (KernelLib ref) arg = do + ptr <- readIORef ref + wrapKernelFun ptr arg getTempDir :: IO FilePath getTempDir = diff --git a/src/Data.hs b/src/CHAD/Data.hs index e86aaa6..8c7605c 100644 --- a/src/Data.hs +++ b/src/CHAD/Data.hs @@ -8,16 +8,17 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module Data (module Data, (:~:)(Refl)) where +module CHAD.Data (module CHAD.Data, (:~:)(Refl), If) where import Data.Functor.Product import Data.GADT.Compare import Data.GADT.Show import Data.Some +import Data.Type.Bool (If) import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) -import Lemmas (Append) +import CHAD.Lemmas (Append) data Dict c where @@ -184,3 +185,8 @@ instance Applicative Bag where instance Semigroup (Bag t) where (<>) = BTwo instance Monoid (Bag t) where mempty = BNone + +data SBool b where + SF :: SBool False + ST :: SBool True +deriving instance Show (SBool b) diff --git a/src/Data/VarMap.hs b/src/CHAD/Data/VarMap.hs index 9c10421..a0d7617 100644 --- a/src/Data/VarMap.hs +++ b/src/CHAD/Data/VarMap.hs @@ -1,10 +1,11 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RoleAnnotations #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} -module Data.VarMap ( +module CHAD.Data.VarMap ( VarMap, empty, insert, @@ -20,16 +21,16 @@ module Data.VarMap ( import Prelude hiding (lookup) -import qualified Data.Map.Strict as Map +import Data.Map.Strict qualified as Map import Data.Map.Strict (Map) import Data.Maybe (mapMaybe) import Data.Some -import qualified Data.Vector.Storable as VS +import Data.Vector.Storable qualified as VS import Unsafe.Coerce -import AST.Env -import AST.Types -import AST.Weaken +import CHAD.AST.Env +import CHAD.AST.Types +import CHAD.AST.Weaken type role VarMap _ nominal -- ensure that 'env' is not phantom @@ -74,7 +75,7 @@ subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env' subMap subenv = let bools = let loop :: Subenv env env' -> [Bool] loop SETop = [] - loop (SEYes sub) = True : loop sub + loop (SEYesR sub) = True : loop sub loop (SENo sub) = False : loop sub in VS.fromList $ loop subenv newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools @@ -89,7 +90,7 @@ superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env superMap subenv = let loop :: Subenv env env' -> Int -> [Int] loop SETop _ = [] - loop (SEYes sub) i = i : loop sub (i+1) + loop (SEYesR sub) i = i : loop sub (i+1) loop (SENo sub) i = loop sub (i+1) newIndices = VS.fromList $ loop subenv 0 diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs new file mode 100644 index 0000000..bfa964b --- /dev/null +++ b/src/CHAD/Drev.hs @@ -0,0 +1,1581 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module CHAD.Drev ( + drev, + freezeRet, + CHADConfig(..), + defaultConfig, + Storage(..), + Descr(..), + Select, +) where + +import Data.Functor.Const +import Data.Some +import Data.Type.Equality (type (==), testEquality) + +import CHAD.Analysis.Identity (ValId(..), validSplitEither) +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.AST.Count +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.Weaken.Auto +import CHAD.Data +import CHAD.Data.VarMap qualified as VarMap +import CHAD.Data.VarMap (VarMap) +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types +import CHAD.Lemmas + + +------------------------------ TAPES AND BINDINGS ------------------------------ + +type family Tape binds where + Tape '[] = TNil + Tape (t : ts) = TPair t (Tape ts) + +tapeTy :: SList STy binds -> STy (Tape binds) +tapeTy SNil = STNil +tapeTy (SCons t ts) = STPair t (tapeTy ts) + +bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds + -> binds :> env2 -> Ex env2 (Tape tapebinds) +bindingsCollectTape SNil SETop _ = ENil ext +bindingsCollectTape (t `SCons` binds) (SEYesR sub) w = + EPair ext (EVar ext t (w @> IZ)) + (bindingsCollectTape binds sub (w .> WSink)) +bindingsCollectTape (_ `SCons` binds) (SENo sub) w = + bindingsCollectTape binds sub (w .> WSink) + +-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds +-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) +-- bindingsCollectTape' binds sub w +-- | Refl <- lemAppendNil @binds +-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env)) + +-- In order from large to small: i.e. in reverse order from what we want, +-- because in a Bindings, the head of the list is the bottom-most entry. +type family TapeUnfoldings binds where + TapeUnfoldings '[] = '[] + TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts + +type family Reverse l where + Reverse '[] = '[] + Reverse (t : ts) = Append (Reverse ts) '[t] + +-- An expression that is always 'snd' +data UnfExpr env t where + UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t + +fromUnfExpr :: UnfExpr env t -> Ex env t +fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ) + +-- - A bunch of 'snd' expressions taking us from knowing that there's a +-- 'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix +-- this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in +-- the environment. +-- - In the extended environment, another bunch of let bindings (these are +-- 'fst' expressions, but no need to know that statically) that project the +-- fsts out of what we introduced above, one for each type in 'ts'. +data Reconstructor env ts = + Reconstructor + (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts))) + (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts) + +ssnoc :: SList f ts -> f t -> SList f (Append ts '[t]) +ssnoc SNil a = SCons a SNil +ssnoc (SCons t ts) a = SCons t (ssnoc ts a) + +sreverse :: SList f ts -> SList f (Reverse ts) +sreverse SNil = SNil +sreverse (SCons t ts) = ssnoc (sreverse ts) t + +stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts) +stapeUnfoldings SNil = SNil +stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts) + +-- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one. +shiftUnfolder + :: STy t + -> SList STy ts + -> Bindings UnfExpr (Tape ts : env) list + -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts]) +shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts)) +shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) = + -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order + -- to expand an 'Append' in the types so that things simplify just enough. + -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list + -- of bindings produced by 'b'. We want to conclude from this that + -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know + -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after + -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step. + BPush (shiftUnfolder newTy ts b) (t, case b of BTop -> UnfExSnd itemTy t + BPush{} -> UnfExSnd itemTy t) + +growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts) +growRecon t ts (Reconstructor unfbs bs) + | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) + , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env) + , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env + = Reconstructor + (shiftUnfolder t ts unfbs) + -- Add a 'fst' at the bottom of the builder stack. + -- First we have to weaken most of 'bs' to skip one more binding in the + -- unfolder stack above it. + (BPush (fst (weakenBindingsE + (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) + (WSink :: env :> (Tape (t : ts) : env))) bs)) + (t + ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $ + wSinks @(Tape (t : ts) : env) + (sappend ts + (sappend (sappend (sreverse (stapeUnfoldings ts)) + (SCons (tapeTy ts) SNil)) + SNil)) + @> IZ)) + +buildReconstructor :: SList STy ts -> Reconstructor env ts +buildReconstructor SNil = Reconstructor BTop BTop +buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) + +-- STRATEGY FOR reconstructBindings +-- +-- binds = [] +-- e : () +-- +-- binds = [c] +-- e : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst e : c +-- +-- binds = [b, c] +-- e : (b, (c, ())) +-- x1 = snd e : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst x1 : c +-- y2 = fst x2 : b +-- +-- binds = [a, b, c] +-- e : (a, (b, (c, ()))) +-- x2 = snd e : (b, (c, ())) +-- x1 = snd x2 : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst x1 : c +-- y2 = fst x2 : b +-- y3 = fst x3 : a + +-- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all +-- the things in the list 'binds', we want to create a let stack that extracts +-- all values from that tuple and in effect "restores" the environment +-- described by 'binds'. The idea is that elsewhere, we took a slice of the +-- environment and saved it all in a tuple to be restored later. We +-- incidentally also add a bunch of additional bindings, namely 'Reverse +-- (TapeUnfoldings binds)', so the calling code just has to skip those in +-- whatever it wants to do. +reconstructBindings :: SList STy binds + -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) + ,SList STy (Reverse (TapeUnfoldings binds))) +reconstructBindings binds = + (\tape -> let Reconstructor unf build = buildReconstructor binds + in fst $ weakenBindingsE (WIdx tape) + (bconcat (mapBindings fromUnfExpr unf) build) + ,sreverse (stapeUnfoldings binds)) + + +---------------------------------- DERIVATIVES --------------------------------- + +d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) +d1op (OAdd t) e = EOp ext (OAdd t) e +d1op (OMul t) e = EOp ext (OMul t) e +d1op (ONeg t) e = EOp ext (ONeg t) e +d1op (OLt t) e = EOp ext (OLt t) e +d1op (OLe t) e = EOp ext (OLe t) e +d1op (OEq t) e = EOp ext (OEq t) e +d1op ONot e = EOp ext ONot e +d1op OAnd e = EOp ext OAnd e +d1op OOr e = EOp ext OOr e +d1op OIf e = EOp ext OIf e +d1op ORound64 e = EOp ext ORound64 e +d1op OToFl64 e = EOp ext OToFl64 e +d1op (ORecip t) e = EOp ext (ORecip t) e +d1op (OExp t) e = EOp ext (OExp t) e +d1op (OLog t) e = EOp ext (OLog t) e +d1op (OIDiv t) e = EOp ext (OIDiv t) e +d1op (OMod t) e = EOp ext (OMod t) e + +-- | Both primal and dual must be duplicable expressions +data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) + | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) + +d2op :: SOp a t -> D2Op a t +d2op op = case op of + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d + OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> + EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d)) + ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d + OLt t -> Linear $ \_ -> pairZero t + OLe t -> Linear $ \_ -> pairZero t + OEq t -> Linear $ \_ -> pairZero t + ONot -> Linear $ \_ -> ENil ext + OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OIf -> Linear $ \_ -> ENil ext + ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext) + OToFl64 -> Linear $ \_ -> ENil ext + ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) + OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) + OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) + OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + where + pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) + pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) + (EZero ext (d2M (STScal t)) (ENil ext)) + where + ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r + ziNil STI32 k = k + ziNil STI64 k = k + ziNil STF32 k = k + ziNil STF64 k = k + ziNil STBool k = k + + d2opUnArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => D2Op (TScal a) t) + -> D2Op (TScal a) t + d2opUnArrangeInt ty float = case ty of + STI32 -> Linear $ \_ -> ENil ext + STI64 -> Linear $ \_ -> ENil ext + STF32 -> float + STF64 -> float + STBool -> Linear $ \_ -> ENil ext + + d2opBinArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) + -> D2Op (TPair (TScal a) (TScal a)) t + d2opBinArrangeInt ty float = case ty of + STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STF32 -> float + STF64 -> float + STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + + floatingD2 :: ScalIsFloating a ~ True + => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r + floatingD2 STF32 k = k + floatingD2 STF64 k = k + + integralD2 :: ScalIsIntegral a ~ True + => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r + integralD2 STI32 k = k + integralD2 STI64 k = k + +desD1E :: Descr env sto -> SList STy (D1E env) +desD1E = d1e . descrList + +-- d1W :: env :> env' -> D1E env :> D1E env' +-- d1W WId = WId +-- d1W WSink = WSink +-- d1W (WCopy w) = WCopy (d1W w) +-- d1W (WPop w) = WPop (d1W w) +-- d1W (WThen u w) = WThen (d1W u) (d1W w) + +conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) +conv1Idx IZ = IZ +conv1Idx (IS i) = IS (conv1Idx i) + +data Idx2 env sto t + = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) + | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t)) + | Idx2Di (Idx (Select env sto "discr") t) + +conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t +conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ +conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ +conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ +conv2Idx (DPush des (_, _, SAccum)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, _, SMerge)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me (IS j) + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, _, SDiscr)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di (IS j) +conv2Idx DTop i = case i of {} + +opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) +opt2UnSparse = go . opt2 + where + go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) + go (STScal STI32) SpAbsent = \_ -> ENil ext + go (STScal STI64) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext) + go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext) + go (STScal STBool) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpScal = id + go (STScal STF64) SpScal = id + go STNil _ = \_ -> ENil ext + go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2) + go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" + + +----------------------------------- SPARSITY ----------------------------------- + +expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) +expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e +expandSparse t (SpSparse sp) epr e = + EMaybe ext + (EZero ext (d2M t) (d2zeroInfo t epr)) + (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ)) + e +expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) +expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = + eunPair epr $ \w1 epr1 epr2 -> + eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> + EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) + (expandSparse t2 s2 (weakenExpr w2 epr2) e2) +expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ECase ext (weakenExpr WSink epr) + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) + (ECase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") + (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STMaybe t) (SpMaybe s) epr e = + EMaybe ext + (ENothing ext (d2 t)) + (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr + in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) + e +expandSparse (STArr _ t) (SpArr s) epr e = + ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e +expandSparse (STScal STF32) SpScal _ e = e +expandSparse (STScal STF64) SpScal _ e = e +expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" + +subenvPlus :: SBool req1 -> SBool req2 + -> SList SMTy env + -> SubenvS env env1 -> SubenvS env env2 + -> (forall env3. SubenvS env env3 + -> Injection req1 (Tup env1) (Tup env3) + -> Injection req2 (Tup env2) (Tup env3) + -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) + -> r) + -> r +-- don't destroy effects! +subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext) + +subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> + k (SENo sub3) s31 s32 pl + +subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = + subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + Noinj + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) +subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k + | Just zero1 <- cheapZero (applySparse sp1 t) = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + (Inj $ \e2 -> EPair ext (inj23 e2) zero1) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) + | otherwise = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes (SpSparse sp1) sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (EJust ext e1b)) + (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) + +subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = + subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> + k sub3 minj13 minj23 (flip pl) + +subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl -> + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus -> + k (SEYes sp3 sub3) + (withInj2 minj13 mTinj13 $ \inj13 tinj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (tinj13 e1b)) + (withInj2 minj23 mTinj23 $ \inj23 tinj23 -> + \e2 -> eunPair e2 $ \_ e2a e2b -> + EPair ext (inj23 e2a) (tinj23 e2b)) + (\e1 e2 -> + ELet ext e1 $ + ELet ext (weakenExpr WSink e2) $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) + (EFst ext (EVar ext (typeOf e2) IZ))) + (plus + (ESnd ext (EVar ext (typeOf e1) (IS IZ))) + (ESnd ext (EVar ext (typeOf e2) IZ)))) + +expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs + -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SNil SETop _ = ENil ext +expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = + eunPair e $ \w1 e1 e2 -> + EPair ext + (expandSubenvZeros (w1 .> WPop w) ts sub e1) + (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) +expandSubenvZeros w (SCons t ts) (SENo sub) e = + EPair ext + (expandSubenvZeros (WPop w) ts sub e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) + + +--------------------------------- ACCUMULATORS --------------------------------- + +fromArrayValId :: Maybe (ValId t) -> Maybe Int +fromArrayValId (Just (VIArr i _)) = Just i +fromArrayValId _ = Nothing + +accumPromote :: forall dt env sto proxy r. + proxy dt + -> Descr env sto + -> (forall stoRepl envPro. + (Select env stoRepl "merge" ~ '[]) + => Descr env stoRepl + -- ^ A revised environment description that switches + -- arrays (used in the OccEnv) that are currently on + -- "merge" storage, to "accum" storage. + -> SList STy envPro + -- ^ New entries on top of the original dual environment, + -- that house the accumulators for the promoted arrays in + -- the original environment. + -> Subenv (Select env sto "merge") envPro + -- ^ The promoted entries were merge entries in the + -- original environment. + -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) + -- ^ All entries that were accumulators are still + -- accumulators. + -> VarMap Int (D2AcE (Select env stoRepl "accum")) + -- ^ Accumulator map for _only_ the the newly allocated + -- accumulators. + -> (forall shbinds. + SList STy shbinds + -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) + :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum")))) + -- ^ A weakening that converts a computation in the + -- revised environment to one in the original environment + -- extended with some accumulators. + -> r) + -> r +accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) +accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of + -- Accumulators are left as-is + SAccum -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + envpro + prosub + (SEYesR accrevsub) + (VarMap.sink1 accumMap) + (\shbinds -> + autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) + (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) + (#pro :++: #d :++: #shb :++: #acc :++: #tl) + .> WCopy (wf shbinds) + .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) + (#d :++: #shb :++: #acc :++: #tl) + (#acc :++: (#d :++: #shb :++: #tl))) + + SMerge -> case t of + -- Discrete values are left as-is + _ | isDiscrete t -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> + k (storepl `DPush` (t, vid, SDiscr)) + envpro + (SENo prosub) + accrevsub + accumMap' + wf + + -- Values with "merge" storage are promoted to an accumulator in envPro + _ -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + (t `SCons` envpro) + (SEYesR prosub) + (SENo accrevsub) + (let accumMap' = VarMap.sink1 accumMap + in case fromArrayValId vid of + Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap' + Nothing -> accumMap') + (\(shbinds :: SList _ shbinds) -> + let shbindsC = slistMap (\_ -> Const ()) shbinds + in + -- wf: + -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WCopy wf: + -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WPICK: ^ THESE TWO || + -- goal: | ARE EQUAL || + -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + WCopy (wf shbinds) + .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC) + (WId @(D2AcE (Select env1 stoRepl "accum")))) + + -- Discrete values are left as-is, nothing to do + SDiscr -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SDiscr)) + envpro + prosub + accrevsub + accumMap + wf + where + isDiscrete :: STy t' -> Bool + isDiscrete = \case + STNil -> True + STPair a b -> isDiscrete a && isDiscrete b + STEither a b -> isDiscrete a && isDiscrete b + STLEither a b -> isDiscrete a && isDiscrete b + STMaybe a -> isDiscrete a + STArr _ a -> isDiscrete a + STScal st -> case st of + STI32 -> True + STI64 -> True + STF32 -> False + STF64 -> False + STBool -> True + STAccum{} -> False + + +---------------------------- RETURN TRIPLE FROM CHAD --------------------------- + +data Ret env0 sto sd t = + forall shbinds tapebinds contribs. + Ret (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (Ex (Append shbinds (D1E env0)) (D1 t)) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) +deriving instance Show (Ret env0 sto sd t) + +type data TyTyPair = MkTyTyPair Ty Ty + +data SingleRet env0 sto (pair :: TyTyPair) = + forall shbinds tapebinds. + SingleRet + (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (RetPair env0 sto (D1E env0) shbinds tapebinds pair) + +-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds +-- -> Subenv shbinds tapebinds +-- -> Ex (Append shbinds (D1E env0)) (D1 t) +-- -> SubenvS (D2E (Select env0 sto "merge")) contribs +-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) +-- -> SingleRet env0 sto (MkTyTyPair sd t) +-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2) +-- {-# COMPLETE Ret1 #-} + +data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where + RetPair :: forall sd t contribs -- existentials + env0 sto env shbinds tapebinds. -- universals + Ex (Append shbinds env) (D1 t) + -> SubenvS (D2E (Select env0 sto "merge")) contribs + -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) + -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t) +deriving instance Show (RetPair env0 sto env shbinds tapebinds pair) + +data Rets env0 sto env list = + forall shbinds tapebinds. + Rets (Bindings Ex env shbinds) + (Subenv shbinds tapebinds) + (SList (RetPair env0 sto env shbinds tapebinds) list) +deriving instance Show (Rets env0 sto env list) + +toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t) +toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2) + +weakenRetPair :: SList STy shbinds -> env :> env' + -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair +weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 + +weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list +weakenRets w (Rets binds tapesub list) = + let (binds', _) = weakenBindingsE w binds + in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) + +rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f. + Descr env0 sto + -> SList f b1 -> SList f b2 + -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 + -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair + -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair +rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2) + | Refl <- lemAppendAssoc @b2 @b1 @env = + RetPair e1 sub + (weakenExpr (autoWeak + (#d (auto1 @sd) + &. #t2 (subList b2 subtape2) + &. #t1 (subList b1 subtape1) + &. #tl (d2ace (select SAccum descr))) + (#d :++: (#t2 :++: #tl)) + (#d :++: ((#t2 :++: #t1) :++: #tl))) + e2) + +retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list +retConcat _ SNil = Rets BTop SETop SNil +retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list) + | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs + <- weakenRets (sinkWithBindings e0) (retConcat descr list) + , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) + , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) + = Rets (bconcat e0 binds) + (subenvConcat subtape subtape2) + (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1) + sub + (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2)) + (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds) + subtape subtape2) + pairs)) + +freezeRet :: Descr env sto + -> Ret env sto (D2 t) t + -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) +freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = + let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0 + e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 + tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) + library = #d (auto1 @(D2 t)) + &. #tape (subList (bindingsBinds e0) subtape) + &. #shbinds (bindingsBinds e0) + &. #d2ace (d2ace (select SAccum descr)) + &. #tl (desD1E descr) + &. #contribs (SCons tContribs SNil) + in letBinds e0' $ + EPair ext + (weakenExpr wInsertD2Ac e1) + (ELet ext (weakenExpr (autoWeak library + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) + (#shbinds :++: #d :++: #d2ace :++: #tl)) + e2') $ + expandSubenvZeros + (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) + .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) + (select SMerge descr) sub (EVar ext tContribs IZ)) + + +---------------------------- THE CHAD TRANSFORMATION --------------------------- + +drev :: forall env sto sd t. + (?config :: CHADConfig) + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2 t) sd + -> Expr ValId env t -> Ret env sto sd t +drev des _ sd | isAbsent sd = + \e -> + Ret BTop + SETop + (drevPrimal des e) + (subenvNone (d2e (select SMerge des))) + (ENil ext) +drev _ _ SpAbsent = error "Absent should be isAbsent" + +drev des accumMap (SpSparse sd) = + \e -> + case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + e1 + sub' + (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) + (inj2 (ENil ext)) + (inj1 (weakenExpr (WCopy WSink) e2))) + } + +drev des accumMap sd = \case + EVar _ t i -> + case conv2Idx des i of + Idx2Ac accI -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (d2e (select SMerge des))) + (let ty = applySparse sd (d2M t) + in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) + + Idx2Me tupI -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvOnehot (d2e (select SMerge des)) tupI sd) + (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ)) + + Idx2Di _ -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + ELet _ (rhs :: Expr _ _ a) body + | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body + , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs + , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0 + , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds + , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) + -> + subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> + let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in + Ret (bconcat (rhs0 `bpush` rhs1) body0') + (subenvConcat subtapeRHS subtapeBody) + (weakenExpr wbody0' body1) + subBoth + (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd) + &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody) + &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #tl) + (#d :++: (#body :++: #rhs) :++: #tl)) + body2) $ + ELet ext + (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ + plus_RHS_Body + (EVar ext (contribTupTy des subRHS) IZ) + (EFst ext (EVar ext bodyResType (IS IZ)))) + + EPair _ a b + | SpPair sd1 sd2 <- sd + , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil + , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EPair ext a1 b1) + subBoth + (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) + (weakenExpr (WCopy WSink) a2)) $ + ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + (weakenExpr (WCopy (WSink .> WSink)) b2)) $ + plus_A_B + (EVar ext (contribTupTy des subA) (IS IZ)) + (EVar ext (contribTupTy des subB) IZ)) + + EFst _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e + , STPair t1 _ <- typeOf e -> + Ret e0 + subtape + (EFst ext e1) + sub + (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $ + weakenExpr (WCopy WSink) e2) + + ESnd _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e + , STPair _ t2 <- typeOf e -> + Ret e0 + subtape + (ESnd ext e1) + sub + (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $ + weakenExpr (WCopy WSink) e2) + + -- Don't need to handle ENil, because its cotangent is always absent! + -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext) + + EInl _ t2 e + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + (EInl ext (d1 t2) e1) + sub' + (ELCase ext + (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ) + (inj2 $ ENil ext) + (inj1 $ weakenExpr (WCopy WSink) e2) + (EError ext (contribTupTy des sub') "inl<-dinr")) + + EInr _ t1 e + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + (EInr ext (d1 t1) e1) + sub' + (ELCase ext + (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ) + (inj2 $ ENil ext) + (EError ext (contribTupTy des sub') "inr<-dinl") + (inj1 $ weakenExpr (WCopy WSink) e2)) + + ECase _ e (a :: Expr _ _ t) b + | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e + , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge + , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge + , let (bindids1, bindids2) = validSplitEither (extOf e) + , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 + <- drevScoped des accumMap t1 storage1 bindids1 sd a + , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2 + <- drevScoped des accumMap t2 storage2 bindids2 sd b + , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e + , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) + , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) + , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB + , let tapeA = tapeTy subtapeListA + , let tapeB = tapeTy subtapeListB + , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env))) + (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) + (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB + , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) + , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0 + , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0 + , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) + , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) + , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) + , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env) + , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env)) + , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env)) + -> + subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> + subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> + Ret (e0 `bpush` ECase ext e1 + (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) + (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))) + (SEYesR subtapeE) + (EFst ext (EVar ext tPrimal IZ)) + subOut + (elet + (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) + (let (rebinds, prerebinds) = reconstructBindings subtapeListA + in letBinds (rebinds IZ) $ + ELet ext + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #ta0 subtapeListA + &. #prea0 prerebinds + &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) + &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) + &. #tl (d2ace (select SAccum des))) + (#d :++: #ta0 :++: #tl) + (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) + a2) $ + EPair ext (sAB_A $ EFst ext (evar IZ)) + (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) + (let (rebinds, prerebinds) = reconstructBindings subtapeListB + in letBinds (rebinds IZ) $ + ELet ext + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #tb0 subtapeListB + &. #preb0 prerebinds + &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) + &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) + &. #tl (d2ace (select SAccum des))) + (#d :++: #tb0 :++: #tl) + (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) + b2) $ + EPair ext (sAB_B $ EFst ext (evar IZ)) + (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $ + plus_AB_E + (EFst ext (evar IZ)) + (ELet ext (ESnd ext (evar IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_,_])) e2)) + + EConst _ t val -> + Ret BTop + SETop + (EConst ext t val) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EOp _ op e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e -> + case d2op op of + Linear d2opfun -> + Ret e0 + subtape + (d1op op e1) + sub + (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) + (weakenExpr (WCopy WSink) e2)) + Nonlinear d2opfun -> + Ret (e0 `bpush` e1) + (SEYesR subtape) + (d1op op $ EVar ext (d1 (typeOf e)) IZ) + sub + (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) + (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) + (weakenExpr (WCopy (wSinks' @[_,_])) e2)) + + ECustom _ _ tb _ srce pr du a b + -- allowed to ignore a2 because 'a' is the part of the input that is inactive + | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> + case isDense (d2M (typeOf srce)) sd of + Just Refl -> + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr) + `bpush` ESnd ext (EVar ext (typeOf pr) IZ)) + (SEYesR (SENo (SENo (SENo bsubtape)))) + (EFst ext (EVar ext (typeOf pr) (IS IZ))) + bsub + (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink)) b2) + + Nothing -> + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) + (SEYesR (SENo (SENo bsubtape))) + (EFst ext (EVar ext (typeOf pr) IZ)) + bsub + (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape + ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent + (EFst ext (EVar ext (typeOf pr) (IS (IS IZ)))) + (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $ + ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2) + + ERecompute _ e -> + deleteUnused (descrList des) (occCountAll e) $ \usedSub -> + let smallE = unsafeWeakenWithSubenv usedSub e in + subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> + case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 -> + let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in + Ret (collectBindings (desD1E des) subD1eUsed) + (subenvAll (desD1E usedDes)) + (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) + (subenvCompose subMergeUsed' sub) + (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ + weakenExpr + (autoWeak (#d (auto1 @sd) + &. #shbinds (bindingsBinds e0) + &. #tape (subList (bindingsBinds e0) subtape) + &. #d1env (desD1E usedDes) + &. #tl' (d2ace (select SAccum usedDes)) + &. #tl (d2ace (select SAccum des))) + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) + (#shbinds :++: #d :++: #d1env :++: #tl)) + e2) + } + + EError _ t s -> + Ret BTop + SETop + (EError ext (d1 t) s) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EConstArr _ n t val -> + Ret BTop + SETop + (EConstArr ext n t val) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty) + | SpArr @_ @sdElt sdElt <- sd + , let eltty = typeOf ef + , shty :: STy shty <- tTup (sreplicate ndim tIx) + , Refl <- indexTupD1Id ndim -> + drevLambda des accumMap (shty, SDiscr) sdElt ef $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> + let library = #ix (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d (auto1 @sdElt) + &. #tape (auto1 @e_tape) + &. #pro (d2ace provars) + &. #d2acEnv (d2ace (select SAccum des)) + &. #darr (auto1 @(TArr ndim sdElt)) + &. #tapearr (auto1 @(TArr ndim e_tape)) in + Ret (proPrimalBinds + `bpush` weakenExpr (wSinks (d1e provars)) + (EBuild ext ndim + (drevPrimal des she) + (letBinds e0 $ + EPair ext e1 e1tape)) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ)) + (SEYesR (SENo (subenvAll (d1e provars)))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ))) + (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub) + (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in + ESnd ext $ + wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ + EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv) + (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) + e2) + + EMap _ ef (earr :: Expr _ _ (TArr n a)) + | SpArr sdElt <- sd + , let STArr ndim t1 = typeOf earr + t2 = typeOf ef -> + drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 -> + case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds + ttape = typeOf ef1tape + library = #d1env (desD1E des) + &. #a0 (bindingsBinds ea0) + &. #atapebinds (subList (bindingsBinds ea0) easubtape) + &. #propr (d1e provars) + &. #x (d1 t1 `SCons` SNil) + &. #parr (STArr ndim (d1 t1) `SCons` SNil) + &. #tapearr (STArr ndim ttape `SCons` SNil) + &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil) + &. #dy (applySparse sdElt (d2 t2) `SCons` SNil) + &. #tape (ttape `SCons` SNil) + &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil) + &. #d2acEnv (d2ace (select SAccum des)) + &. #pro (d2ace provars) + in + subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a -> + Ret (bconcat ea0 proPrimalBinds' + `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1 + `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env)) + (letBinds ef0 $ + EPair ext ef1 ef1tape)) + (EVar ext (STArr ndim (d1 t1)) IZ) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ)) + (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars)))))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ))) + subfa + (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in + elet + (wrapAccum (autoWeak library #propr layout) $ + emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $ + elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv) + (#tape :++: #dy :++: #dytape :++: #pro :++: layout)) + ef2) + (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ)) + (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $ + plus_f_a + (ESnd ext (evar IZ)) + (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout)) + (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ)) + ea2))) + } + + EFold1Inner _ commut origef ex₀ earr + | SpArr @_ @sdElt sdElt <- sd + , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr + , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> + drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) spEf wrapAccum ef2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in + let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape) + bogTy = STArr (SS ndim) bogEltTy + primalTy = STPair (STArr ndim (d1 eltty)) bogTy + library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil) + &. #parr (auto1 @(TArr (S n) (D1 elt))) + &. #px₀ (auto1 @(D1 elt)) + &. #px (auto1 @(D1 elt)) + &. #pzi (auto1 @(ZeroInfo (D2 elt))) + &. #primal (primalTy `SCons` SNil) + &. #darr (auto1 @(TArr n sdElt)) + &. #d (auto1 @(D2 elt)) + &. #x₀abinds (bindingsBinds bindsx₀a) + &. #fbinds (bindingsBinds ef0) + &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) + &. #ftape (auto1 @ef_tape) + &. #bogelt (bogEltTy `SCons` SNil) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace provars) + &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) + wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in + subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> + subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f -> + Ret (bconcat bindsx₀a proPrimalBinds' + `bpush` weakenExpr wOverPrimalBindings ex₀1 + `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) + `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 + `bpush` EFold1InnerD1 ext commut + (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in + weakenExpr (autoWeak library (#xy :++: #d1env) layout) + (letBinds ef0 $ + EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape) + ef1 + (EPair ext + (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: #xy :++: #d1env) @> IZ)) + ef1tape))) + (EVar ext (d1 eltty) (IS (IS IZ))) + (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) + (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars))))))) + (EFst ext (EVar ext primalTy IZ)) + subx₀af + (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in + elet + (wrapAccum (autoWeak library #propr layout1) $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut + (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $ + let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in + expandSparse (STPair eltty eltty) spEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $ + weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2) + (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $ + plus_x₀a_f + (plus_x₀_a + (elet (EIdx0 ext + (EFold1Inner ext Commut + (let t = STPair (d2 eltty) (d2 eltty) + in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) + (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) + (eflatten (EFst ext (EFst ext (evar IZ)))))) $ + weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) + ex₀2) + (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ + subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) + (ESnd ext (evar IZ))) + + EUnit _ e + | SpArr sdElt <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e -> + Ret e0 + subtape + (EUnit ext e1) + sub + (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $ + weakenExpr (WCopy WSink) e2) + + EReplicate1Inner _ en e + -- We're allowed to differentiate 'en' as primal-only here because its output is discrete. + | SpArr sdElt <- sd + , let STArr ndim eltty = typeOf e -> + -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero. + sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 -> + Ret binds + subtape + (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) + sub + (ELet ext (EFold1Inner ext Commut + (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty)) + in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) + (inj2 (ENil ext)) + (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ + weakenExpr (WCopy WSink) e2) + } + + EIdx0 _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e + , STArr _ t <- typeOf e -> + Ret e0 + subtape + (EIdx0 ext e1) + sub + (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $ + weakenExpr (WCopy WSink) e2) + + EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" + {- + EIdx1 _ e ei + -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. + | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) + <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil + , STArr (SS n) eltty <- typeOf e -> + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)) + (SEYesR (SENo subtape)) + (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) + (weakenExpr (WSink .> WSink) ei1)) + sub + (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + -} + + EIdx _ e ei + -- We're allowed to differentiate ei as primal because its output is discrete. + | STArr n eltty <- typeOf e + , Refl <- indexTupD1Id n + , let tIxN = tTup (sreplicate n tIx) -> + sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (typeOf e1) IZ) + `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)) + (SEYesR (SEYesR (SENo subtape))) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + sub + (ELet ext + (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) + (SAPArrIdx SAPHere) + (EPair ext + (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ + makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) + (ENil ext)) + (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + + EShape _ e + -- Allowed to differentiate e as primal because the output of EShape is + -- discrete, hence we'd be passing a zero cotangent to e anyway. + | STArr n _ <- typeOf e + , Refl <- indexTupD1Id n -> + Ret BTop + SETop + (EShape ext (drevPrimal des e)) + (subenvNone (d2eM (select SMerge des))) + (ENil ext) + + ESum1Inner _ e + | SpArr sd' <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , STArr (SS n) t <- typeOf e -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ)) + (SEYesR (SENo subtape)) + (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) + sub + (ELet ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + + EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e + EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e + + EReshape _ n esh e + | SpArr sd' <- sd + , STArr orign t <- typeOf e + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , Refl <- indexTupD1Id n -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ)) + (SEYesR (SENo subtape)) + (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh)) + (EVar ext (STArr orign (d1 t)) (IS IZ))) + sub + (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ)) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + + EZip _ a b + | SpArr sd' <- sd + , STArr n t1 <- typeOf a + , STArr _ t2 <- typeOf b -> + splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE -> + case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons` + toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of + { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EZip ext a1 b1) + subBoth + (case pairSplitE of + Left Refl -> + let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in + plus_A_B + (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2) + (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2) + Right f -> f IZ $ \wrapPair pick1 pick2 -> + elet (emap (wrapPair (EPair ext pick1 pick2)) + (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $ + plus_A_B + (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2) + (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2)) + } + + ENothing{} -> err_unsupported "ENothing" + EJust{} -> err_unsupported "EJust" + EMaybe{} -> err_unsupported "EMaybe" + ELNil{} -> err_unsupported "ELNil" + ELInl{} -> err_unsupported "ELInl" + ELInr{} -> err_unsupported "ELInr" + ELCase{} -> err_unsupported "ELCase" + + EWith{} -> err_accum + EZero{} -> err_monoid + EDeepZero{} -> err_monoid + EPlus{} -> err_monoid + EOneHot{} -> err_monoid + + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + + where + err_accum = error "Accumulator operations unsupported in the source program" + err_monoid = error "Monoid operations unsupported in the source program" + err_unsupported s = error $ "CHAD: unsupported " ++ s + err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program" + + contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) + contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) + +deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2s t) sd + -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) +deriv_extremum extremum des accumMap sd e + | at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> + Ret (e0 `bpush` e1 + `bpush` extremum (EVar ext at IZ)) + (SEYesR (SEYesR subtape)) + (EVar ext at' IZ) + sub + (ELet ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) + (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (inj2 (ENil ext))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + +data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) + +data RetScoped env0 sto a s sd t = + forall shbinds tapebinds contribs sa. + RetScoped + (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds + (Subenv (Append shbinds '[D1 a]) tapebinds) + (Ex (Append shbinds (D1E (a : env0))) (D1 t)) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + -- ^ merge contributions to the _enclosing_ merge environment + (Sparse (D2 a) sa) + -- ^ contribution to the argument + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) + (If (s == "discr") (Tup contribs) + (TPair (Tup contribs) sa))) + -- ^ the merge contributions, plus the cotangent to the argument + -- (if there is any) +deriving instance Show (RetScoped env0 sto a s sd t) + +drevScoped :: forall a s env sto sd t. + (?config :: CHADConfig) + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> STy a -> Storage s -> Maybe (ValId a) + -> Sparse (D2 t) sd + -> Expr ValId (a : env) t + -> RetScoped env sto a s sd t +drevScoped des accumMap argty argsto argids sd expr = case argsto of + SMerge + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> + case sub of + SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2 + SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext)) + + SAccum + | chcSmartWith ?config + , Just (VIArr i _) <- argids + , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap + , Just Refl <- testEquality foundTy (STAccum (d2M argty)) + , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr + , Refl <- lemAppendNil @tapebinds -> + -- Our contribution to the binding's cotangent _here_ is zero (absent), + -- because we're contributing to an earlier binding of the same value + -- instead. + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $ + let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in + ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ + weakenExpr (autoWeak (#d (auto1 @sd) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: #body :++: #tl)) + (EPair ext e2 (ENil ext)) + + | let accumMap' = case argids of + Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) + _ -> VarMap.sink1 accumMap + , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr -> + let library = #d (auto1 @sd) + &. #p (auto1 @(D1 a)) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des)) + in + RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ + let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in + EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $ + weakenExpr (autoWeak library + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: (#body :++: #p) :++: #tl)) + e2 + + SDiscr + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 + +drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) + => Descr env sto + -> VarMap Int (D2AcE (Select env sto "accum")) + -> (STy a, Storage s) + -> Sparse (D2 t) dt + -> Expr ValId (a : env) t + -> (forall provars shbinds tape d2a'. + SList STy provars + -> Subenv (D2E (Select env sto "merge")) (D2E provars) + -> Bindings Ex (D1E env) (D1E provars) -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator) + -> Bindings Ex (D1 a : D1E env) shbinds + -> Ex (Append shbinds (D1 a : D1E env)) (D1 t) + -> Ex (Append shbinds (D1 a : D1E env)) tape + -> Sparse (D2 a) d2a' + -> (forall env' b. + D1E provars :> env' + -> Ex (Append (D2AcE provars) env') b + -> Ex ( env') (TPair b (Tup (D2E provars)))) + -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' + -> r) + -> r +drevLambda des accumMap (argty, argsto) sd origef k = + let t = typeOf origef in + deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') -> + let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + case prf1 prodes argty argsto of { Refl -> + case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> + let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in + extractContrib prodes argty argsto subEf $ \argSp getSparseArg -> + let library = #fbinds (bindingsBinds ef0) + &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) + &. #ftape (auto1 @(Tape e_tape)) + &. #arg (d1 argty `SCons` SNil) + &. #d (applySparse sd (d2 t) `SCons` SNil) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes) + &. #propr (d1e envPro) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace envPro) + &. #efPrerebinds efPrerebinds in + k envPro + (subenvD2E (subenvCompose subMergeUsed proSub)) + mergePrimalBindings + (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0)) + (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: #arg :++: #d1env)) + ef1) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env))) + argSp + (\wpro1 body -> + uninvertTup (d2e envPro) (typeOf body) $ + makeAccumulators wpro1 envPro $ + body) + (letBinds (efRebinds IZ) $ + weakenExpr + (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv) + .> wPro (subList (bindingsBinds ef0) subtapeEf)) + (getSparseArg ef2)) + }} + where + extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False) + => proxy env sto -> proxy2 a -> Storage s + -- if s == "merge", this simplifies to SubenvS '[D2 a] t' + -- if s == "discr", this simplifies to SubenvS '[] t' + -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t' + -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r + extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id + extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext) + extractContrib _ _ SDiscr SETop k' = k' SpAbsent id + + prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s + -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum" + prf1 _ _ SMerge = Refl + prf1 _ _ SDiscr = Refl + +-- TODO: proper primal-only transform that doesn't depend on D1 = Id +drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) +drevPrimal des e + | Refl <- d1Identity (typeOf e) + , Refl <- d1eIdentity (descrList des) + = mapExt (const ext) e diff --git a/src/CHAD/Drev/Accum.hs b/src/CHAD/Drev/Accum.hs new file mode 100644 index 0000000..6f25f11 --- /dev/null +++ b/src/CHAD/Drev/Accum.hs @@ -0,0 +1,72 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +-- | TODO this module is a grab-bag of random utility functions that are shared +-- between CHAD.Drev and CHAD.Drev.Top. +module CHAD.Drev.Accum where + +import CHAD.AST +import CHAD.Data +import CHAD.Drev.Types +import CHAD.AST.Env + + +d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) +d2zeroInfo STNil _ = ENil ext +d2zeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) +d2zeroInfo STEither{} _ = ENil ext +d2zeroInfo STLEither{} _ = ENil ext +d2zeroInfo STMaybe{} _ = ENil ext +d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e +d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext +d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t)) +d2deepZeroInfo STNil _ = ENil ext +d2deepZeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2) +d2deepZeroInfo (STEither a b) e = + ECase ext e + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STLEither a b) e = + elcase e + (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b))) + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STMaybe a) e = + emaybe e + (ENothing ext (tDeepZeroInfo (d2M a))) + (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ))) +d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e +d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext +d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +-- The weakening is necessary because we need to initialise the created +-- accumulators with zeros. Those zeros are deep and need full primals. This +-- means, in the end, that primals corresponding to environment entries +-- promoted to an accumulator with accumPromote in CHAD need to be stored for +-- the dual. +makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators _ SNil e = e +makeAccumulators w (t `SCons` envpro) e = + makeAccumulators (WPop w) envpro $ + EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = + ELet ext (uninvertTup list (STPair tcore t) e) $ + let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding + in EPair ext + (EFst ext (EFst ext (EVar ext recT IZ))) + (EPair ext + (ESnd ext (EVar ext recT IZ)) + (ESnd ext (EFst ext (EVar ext recT IZ)))) + +subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') +subenvD1E SETop = SETop +subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) +subenvD1E (SENo sub) = SENo (subenvD1E sub) diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/Drev/EnvDescr.hs index 4c287d7..5a90303 100644 --- a/src/CHAD/EnvDescr.hs +++ b/src/CHAD/Drev/EnvDescr.hs @@ -7,18 +7,18 @@ {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.EnvDescr where +module CHAD.Drev.EnvDescr where import Data.Kind (Type) import Data.Some import GHC.TypeLits (Symbol) -import Analysis.Identity (ValId(..)) -import AST.Env -import AST.Types -import AST.Weaken -import CHAD.Types -import Data +import CHAD.Analysis.Identity (ValId(..)) +import CHAD.AST.Env +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types type Storage :: Symbol -> Type @@ -52,12 +52,12 @@ subDescr :: Descr env sto -> Subenv env env' -> r) -> r subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, vid, sto)) (SEYes sub) k = +subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of - SMerge -> k (des' `DPush` (t, vid, sto)) (SEYes submerge) subaccum (SEYes subd1e) - SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYes subaccum) (SEYes subd1e) - SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYes subd1e) + SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e) + SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e) + SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e) subDescr (des `DPush` (_, _, sto)) (SENo sub) k = subDescr des sub $ \des' submerge subaccum subd1e -> case sto of @@ -82,3 +82,15 @@ select s@SDiscr (DPush des (_, _, SMerge)) = select s des select s@SAccum (DPush des (_, _, SDiscr)) = select s des select s@SMerge (DPush des (_, _, SDiscr)) = select s des select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des) + +selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s) +selectSub _ DTop = SETop +selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des) diff --git a/src/CHAD/Top.hs b/src/CHAD/Drev/Top.hs index 2c01178..65b4dee 100644 --- a/src/CHAD/Top.hs +++ b/src/CHAD/Drev/Top.hs @@ -1,6 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImportQualifiedPost #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} @@ -8,18 +8,20 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module CHAD.Top where +module CHAD.Drev.Top where -import Analysis.Identity -import AST -import AST.SplitLets -import AST.Weaken.Auto -import CHAD -import CHAD.Accum -import CHAD.EnvDescr -import CHAD.Types -import Data -import qualified Data.VarMap as VarMap +import CHAD.Analysis.Identity +import CHAD.AST +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.SplitLets +import CHAD.AST.Weaken.Auto +import CHAD.Data +import CHAD.Data.VarMap qualified as VarMap +import CHAD.Drev +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types type family MergeEnv env where @@ -41,38 +43,25 @@ mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r accumDescr SNil k = k DTop accumDescr (t `SCons` env) k = accumDescr env $ \des -> - if hasArrays t then k (des `DPush` (t, Nothing, SAccum)) - else k (des `DPush` (t, Nothing, SMerge)) - -d1Identity :: STy t -> D1 t :~: t -d1Identity = \case - STNil -> Refl - STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STMaybe t | Refl <- d1Identity t -> Refl - STArr _ t | Refl <- d1Identity t -> Refl - STScal _ -> Refl - STAccum{} -> error "Accumulators not allowed in input program" - -d1eIdentity :: SList STy env -> D1E env :~: env -d1eIdentity SNil = Refl -d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl + if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum)) + else k (des `DPush` (t, Nothing, SMerge)) reassembleD2E :: Descr env sto + -> D1E env :> env' -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) -> Ex env' (Tup (D2E env)) -reassembleD2E DTop _ = ENil ext -reassembleD2E (des `DPush` (_, _, SAccum)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ))) - (ESnd ext (EVar ext (typeOf e) IZ)))) - (ESnd ext (EFst ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (_, _, SMerge)) e = - ELet ext e $ - EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ)) - (EFst ext (ESnd ext (EVar ext (typeOf e) IZ))))) - (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ))) -reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (EZero ext t) +reassembleD2E DTop _ _ = ENil ext +reassembleD2E (des `DPush` (_, _, SAccum)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e1 $ \w2 e11 e12 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12 +reassembleD2E (des `DPush` (_, _, SMerge)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e2 $ \w2 e21 e22 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22 +reassembleD2E (des `DPush` (t, _, SDiscr)) w e = + EPair ext (reassembleD2E des (WPop w) e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) chad config env (term :: Ex env t) @@ -82,21 +71,22 @@ chad config env (term :: Ex env t) let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) tvar = STPair t1 (tTup (d2e (select SAccum descr))) in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ - makeAccumulators (select SAccum descr) $ + makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $ weakenExpr (autoWeak (#d (auto1 @(D2 t)) &. #acenv (d2ace (select SAccum descr)) &. #tl (d1e env)) (#d :++: #acenv :++: #tl) (#acenv :++: #d :++: #tl)) $ - freezeRet descr (drev descr VarMap.empty term')) $ + freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $ EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) - (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ)) - (ESnd ext (EFst ext (EVar ext tvar IZ))))) + (reassembleD2E descr (WSink .> WSink) + (EPair ext (ESnd ext (EVar ext tvar IZ)) + (ESnd ext (EFst ext (EVar ext tvar IZ))))) | False <- chcArgArrayAccum config , Refl <- mergeEnvNoAccum env , Refl <- mergeEnvOnlyMerge env - = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty term') + = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term') where term' = identityAnalysis env (splitLets term) diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs new file mode 100644 index 0000000..367a974 --- /dev/null +++ b/src/CHAD/Drev/Types.hs @@ -0,0 +1,153 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Drev.Types where + +import CHAD.AST.Accum +import CHAD.AST.Types +import CHAD.Data + + +type family D1 t where + D1 TNil = TNil + D1 (TPair a b) = TPair (D1 a) (D1 b) + D1 (TEither a b) = TEither (D1 a) (D1 b) + D1 (TLEither a b) = TLEither (D1 a) (D1 b) + D1 (TMaybe a) = TMaybe (D1 a) + D1 (TArr n t) = TArr n (D1 t) + D1 (TScal t) = TScal t + +type family D2 t where + D2 TNil = TNil + D2 (TPair a b) = TPair (D2 a) (D2 b) + D2 (TEither a b) = TLEither (D2 a) (D2 b) + D2 (TLEither a b) = TLEither (D2 a) (D2 b) + D2 (TMaybe t) = TMaybe (D2 t) + D2 (TArr n t) = TArr n (D2 t) + D2 (TScal t) = D2s t + +type family D2s t where + D2s TI32 = TNil + D2s TI64 = TNil + D2s TF32 = TScal TF32 + D2s TF64 = TScal TF64 + D2s TBool = TNil + +type family D1E env where + D1E '[] = '[] + D1E (t : env) = D1 t : D1E env + +type family D2E env where + D2E '[] = '[] + D2E (t : env) = D2 t : D2E env + +type family D2AcE env where + D2AcE '[] = '[] + D2AcE (t : env) = TAccum (D2 t) : D2AcE env + +d1 :: STy t -> STy (D1 t) +d1 STNil = STNil +d1 (STPair a b) = STPair (d1 a) (d1 b) +d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STLEither a b) = STLEither (d1 a) (d1 b) +d1 (STMaybe t) = STMaybe (d1 t) +d1 (STArr n t) = STArr n (d1 t) +d1 (STScal t) = STScal t +d1 STAccum{} = error "Accumulators not allowed in input program" + +d1e :: SList STy env -> SList STy (D1E env) +d1e SNil = SNil +d1e (t `SCons` env) = d1 t `SCons` d1e env + +d2M :: STy t -> SMTy (D2 t) +d2M STNil = SMTNil +d2M (STPair a b) = SMTPair (d2M a) (d2M b) +d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STMaybe t) = SMTMaybe (d2M t) +d2M (STArr n t) = SMTArr n (d2M t) +d2M (STScal t) = case t of + STI32 -> SMTNil + STI64 -> SMTNil + STF32 -> SMTScal STF32 + STF64 -> SMTScal STF64 + STBool -> SMTNil +d2M STAccum{} = error "Accumulators not allowed in input program" + +d2 :: STy t -> STy (D2 t) +d2 = fromSMTy . d2M + +d2eM :: SList STy env -> SList SMTy (D2E env) +d2eM SNil = SNil +d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts + +d2e :: SList STy env -> SList STy (D2E env) +d2e = slistMap fromSMTy . d2eM + +d2ace :: SList STy env -> SList STy (D2AcE env) +d2ace SNil = SNil +d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts + + +data CHADConfig = CHADConfig + { -- | D[let] will bind variables containing arrays in accumulator mode. + chcLetArrayAccum :: Bool + , -- | D[case] will bind variables containing arrays in accumulator mode. + chcCaseArrayAccum :: Bool + , -- | Introduce top-level arguments containing arrays in accumulator mode. + chcArgArrayAccum :: Bool + , -- | Place with-blocks around array variable scopes, and redirect accumulations there. + chcSmartWith :: Bool + } + deriving (Show) + +defaultConfig :: CHADConfig +defaultConfig = CHADConfig + { chcLetArrayAccum = False + , chcCaseArrayAccum = False + , chcArgArrayAccum = False + , chcSmartWith = False + } + +chcSetAccum :: CHADConfig -> CHADConfig +chcSetAccum c = c { chcLetArrayAccum = True + , chcCaseArrayAccum = True + , chcArgArrayAccum = True + , chcSmartWith = True } + + +------------------------------------ LEMMAS ------------------------------------ + +indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) +indexTupD1Id SZ = Refl +indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil +lemZeroInfoScal STI32 = Refl +lemZeroInfoScal STI64 = Refl +lemZeroInfoScal STF32 = Refl +lemZeroInfoScal STF64 = Refl +lemZeroInfoScal STBool = Refl + +lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil +lemDeepZeroInfoScal STI32 = Refl +lemDeepZeroInfoScal STI64 = Refl +lemDeepZeroInfoScal STF32 = Refl +lemDeepZeroInfoScal STF64 = Refl +lemDeepZeroInfoScal STBool = Refl + +d1Identity :: STy t -> D1 t :~: t +d1Identity = \case + STNil -> Refl + STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STMaybe t | Refl <- d1Identity t -> Refl + STArr _ t | Refl <- d1Identity t -> Refl + STScal _ -> Refl + STAccum{} -> error "Accumulators not allowed in input program" + +d1eIdentity :: SList STy env -> D1E env :~: env +d1eIdentity SNil = Refl +d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs index f843206..019119c 100644 --- a/src/CHAD/Types/ToTan.hs +++ b/src/CHAD/Drev/Types/ToTan.hs @@ -1,14 +1,14 @@ {-# LANGUAGE GADTs #-} -module CHAD.Types.ToTan where +module CHAD.Drev.Types.ToTan where import Data.Bifunctor (bimap) -import Array -import AST.Types -import CHAD.Types -import Data -import ForwardAD -import Interpreter.Rep +import CHAD.Array +import CHAD.AST.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.ForwardAD +import CHAD.Interpreter.Rep toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) @@ -19,24 +19,25 @@ toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) toTan typ primal der = case typ of STNil -> der - STPair t1 t2 -> case der of - Nothing -> bimap (zeroTan t1) (zeroTan t2) primal - Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal STEither t1 t2 -> case der of Nothing -> bimap (zeroTan t1) (zeroTan t2) primal Just d -> case (primal, d) of (Left p, Left d') -> Left (toTan t1 p d') (Right p, Right d') -> Right (toTan t2 p d') _ -> error "Primal and cotangent disagree on Either alternative" + STLEither t1 t2 -> case (primal, der) of + (_, Nothing) -> Nothing + (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) + (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) + _ -> error "Primal and cotangent disagree on LEither alternative" STMaybe t -> liftA2 (toTan t) primal der - STArr _ t -> case der of - Nothing -> arrayMap (zeroTan t) primal - Just d - | arrayShape primal == arrayShape d -> - arrayGenerateLin (arrayShape primal) $ \i -> - toTan t (arrayIndexLinear primal i) (arrayIndexLinear d i) - | otherwise -> - error "Primal and cotangent disagree on array shape" + STArr _ t + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" STScal sty -> case sty of STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der STAccum{} -> error "Accumulators not allowed in input program" diff --git a/src/Example.hs b/src/CHAD/Example.hs index 3623d03..34ff889 100644 --- a/src/Example.hs +++ b/src/CHAD/Example.hs @@ -5,25 +5,46 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeApplications #-} -module Example where -import Array -import AST -import AST.Pretty -import CHAD -import CHAD.Top -import ForwardAD -import Interpreter -import Language -import Simplify +{-# OPTIONS -Wno-unused-imports #-} +module CHAD.Example where import Debug.Trace -import Example.Types + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Pretty +import CHAD.AST.UnMonoid +import CHAD.Data +import CHAD.Drev +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Example.Types +import CHAD.ForwardAD +import CHAD.Interpreter +import CHAD.Language as L +import CHAD.Simplify -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) +pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +pipeline config term + | Dict <- styKnown (d2 (typeOf term)) = + simplifyFix $ pruneExpr knownEnv $ + simplifyFix $ unMonoid $ + simplifyFix $ chad' config knownEnv $ + simplifyFix $ term + +-- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures +pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO () +pipeline' config term + | Dict <- styKnown (d2 (typeOf term)) = + pprintExpr (pipeline config term) + + bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c bin op a b = EOp ext op (EPair ext a b) @@ -159,8 +180,18 @@ neuralGo = ELet ext (EConst ext STF64 1.0) $ chad defaultConfig knownEnv neural (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of - (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') - _ -> undefined + (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 in trace (ppExpr knownEnv revderiv) $ (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) + +-- The build body uses free variables in a non-linear way, so their primal +-- values are required in the dual of the build. Thus, compositionally, they +-- are stored in the tape from each individual lambda invocation. This results +-- in n copies of y and z, where only one copy would have sufficed. +exUniformFree :: Ex '[R, I64] R +exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $ + let_ #y (#x * 2) $ + let_ #z (#x * 3) $ + idx0 $ sum1i $ + build1 #n $ #i :-> #y * #z + toFloat_ #i diff --git a/src/Example/GMM.hs b/src/CHAD/Example/GMM.hs index 12bbd98..18641e8 100644 --- a/src/Example/GMM.hs +++ b/src/CHAD/Example/GMM.hs @@ -1,10 +1,11 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE TypeApplications #-} -module Example.GMM where +module CHAD.Example.GMM where -import Example.Types -import Language +import CHAD.Data (SList(..)) +import CHAD.Example.Types +import CHAD.Language @@ -31,10 +32,10 @@ import Language -- <https://tomsmeding.com/f/master.pdf> -- -- The 'wrong' argument, when set to True, changes the objective function to --- one with a bug that makes a certain `build` result unused. This triggers +-- one with a bug that makes a certain `build` result unused. This -- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's -- dense, even though it may be a zero (i.e. empty). The "unused" test in --- test/Main.hs tries to isolate this test, but the wrong version of +-- test/Main.hs tries to isolate this case, but the wrong version of -- gmmObjective is here to check (after that bug is fixed) whether it really -- fixes the original bug. gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R diff --git a/src/Example/Types.hs b/src/CHAD/Example/Types.hs index d63159b..1e2f72d 100644 --- a/src/Example/Types.hs +++ b/src/CHAD/Example/Types.hs @@ -1,8 +1,8 @@ {-# LANGUAGE DataKinds #-} -module Example.Types where +module CHAD.Example.Types where -import AST -import Data +import CHAD.AST +import CHAD.Data type R = TScal TF64 diff --git a/src/ForwardAD.hs b/src/CHAD/ForwardAD.hs index b7036dd..0ae88ce 100644 --- a/src/ForwardAD.hs +++ b/src/CHAD/ForwardAD.hs @@ -4,28 +4,30 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module ForwardAD where +module CHAD.ForwardAD where import Data.Bifunctor (bimap) +import Data.Foldable (fold) import System.IO.Unsafe -- import Debug.Trace --- import AST.Pretty +-- import CHAD.AST.Pretty -import Array -import AST -import Compile -import Data -import ForwardAD.DualNumbers -import Interpreter -import Interpreter.Rep +import CHAD.Array +import CHAD.AST +import CHAD.Compile +import CHAD.Data +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter +import CHAD.Interpreter.Rep --- | Tangent along a type (coincides with cotangent for these types) +-- | Tangent along a type (coincides with the cotangent, t'CHAD.Drev.Types.D2', for these types) type family Tan t where Tan TNil = TNil Tan (TPair a b) = TPair (Tan a) (Tan b) Tan (TEither a b) = TEither (Tan a) (Tan b) + Tan (TLEither a b) = TLEither (Tan a) (Tan b) Tan (TMaybe t) = TMaybe (Tan t) Tan (TArr n t) = TArr n (Tan t) Tan (TScal t) = TanS t @@ -45,6 +47,7 @@ tanty :: STy t -> STy (Tan t) tanty STNil = STNil tanty (STPair a b) = STPair (tanty a) (tanty b) tanty (STEither a b) = STEither (tanty a) (tanty b) +tanty (STLEither a b) = STLEither (tanty a) (tanty b) tanty (STMaybe t) = STMaybe (tanty t) tanty (STArr n t) = STArr n (tanty t) tanty (STScal t) = case t of @@ -55,11 +58,18 @@ tanty (STScal t) = case t of STBool -> STNil tanty STAccum{} = error "Accumulators not allowed in input program" +tanenv :: SList STy env -> SList STy (TanE env) +tanenv SNil = SNil +tanenv (t `SCons` env) = tanty t `SCons` tanenv env + zeroTan :: STy t -> Rep t -> Rep (Tan t) zeroTan STNil () = () zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y) zeroTan (STEither a _) (Left x) = Left (zeroTan a x) zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) zeroTan (STMaybe _) Nothing = Nothing zeroTan (STMaybe t) (Just x) = Just (zeroTan t x) zeroTan (STArr _ t) x = fmap (zeroTan t) x @@ -75,9 +85,12 @@ tanScalars STNil () = [] tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y tanScalars (STEither a _) (Left x) = tanScalars a x tanScalars (STEither _ b) (Right y) = tanScalars b y +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y tanScalars (STMaybe _) Nothing = [] tanScalars (STMaybe t) (Just x) = tanScalars t x -tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x +tanScalars (STArr _ t) x = fold $ arrayMap (tanScalars t) x tanScalars (STScal STI32) _ = [] tanScalars (STScal STI64) _ = [] tanScalars (STScal STF32) x = [realToFrac x] @@ -98,6 +111,10 @@ unzipDN (STPair a b) (d1, d2) = unzipDN (STEither a b) d = case d of Left d1 -> bimap Left Left (unzipDN a d1) Right d2 -> bimap Right Right (unzipDN b d2) +unzipDN (STLEither a b) d = case d of + Nothing -> (Nothing, Nothing) + Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) + Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) unzipDN (STMaybe t) d = case d of Nothing -> (Nothing, Nothing) Just d' -> bimap Just Just (unzipDN t d') @@ -120,6 +137,12 @@ dotprodTan (STEither a b) x y = case (x, y) of (Left x', Left y') -> dotprodTan a x' y' (Right x', Right y') -> dotprodTan b x' y' _ -> error "dotprodTan: incompatible Either alternatives" +dotprodTan (STLEither a b) x y = case (x, y) of + (Nothing, _) -> 0.0 -- 0 * y = 0 + (_, Nothing) -> 0.0 -- x * 0 = 0 + (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' + (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' + _ -> error "dotprodTan: incompatible LEither alternatives" dotprodTan (STMaybe t) x y = case (x, y) of (Nothing, Nothing) -> 0.0 (Just x', Just y') -> dotprodTan t x' y' @@ -165,6 +188,7 @@ dnConst :: STy t -> Rep t -> Rep (DN t) dnConst STNil = const () dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2) dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) dnConst (STMaybe t) = fmap (dnConst t) dnConst (STArr _ t) = arrayMap (dnConst t) dnConst (STScal t) = case t of @@ -188,6 +212,11 @@ dnOnehots (STEither t1 t2) e = case e of Left x -> \f -> Left (dnOnehots t1 x (f . Left)) Right y -> \f -> Right (dnOnehots t2 y (f . Right)) +dnOnehots (STLEither t1 t2) e = + case e of + Nothing -> \_ -> Nothing + Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) + Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) dnOnehots (STMaybe t) m = case m of Nothing -> \_ -> Nothing @@ -226,8 +255,10 @@ makeFwdADArtifactInterp env expr = in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr) {-# NOINLINE makeFwdADArtifactCompile #-} -makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t) -makeFwdADArtifactCompile env expr = FwdADArtifact env (typeOf expr) . (unsafePerformIO .) <$> compile (dne env) (dfwdDN expr) +makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String) +makeFwdADArtifactCompile env expr = do + (fun, output) <- compile (dne env) (dfwdDN expr) + return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output) drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr) diff --git a/src/ForwardAD/DualNumbers.hs b/src/CHAD/ForwardAD/DualNumbers.hs index 2f94076..540ec2b 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/CHAD/ForwardAD/DualNumbers.hs @@ -1,11 +1,10 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -- I want to bring various type variables in scope using type annotations in @@ -14,14 +13,14 @@ -- useful here. {-# LANGUAGE PartialTypeSignatures #-} {-# OPTIONS -Wno-partial-type-signatures #-} -module ForwardAD.DualNumbers ( +module CHAD.ForwardAD.DualNumbers ( dfwdDN, DN, DNS, DNE, dn, dne, ) where -import AST -import Data -import ForwardAD.DualNumbers.Types +import CHAD.AST +import CHAD.Data +import CHAD.ForwardAD.DualNumbers.Types dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) @@ -143,16 +142,21 @@ dfwdDN = \case ENothing _ t -> ENothing ext (dn t) EJust _ e -> EJust ext (dfwdDN e) EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2) + ELInl _ t e -> ELInl ext (dn t) (dfwdDN e) + ELInr _ t e -> ELInr ext (dn t) (dfwdDN e) + ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c) EConstArr _ n t x -> scalTyCase t (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) (EConstArr ext n t x)) (EConstArr ext n t x) EBuild _ n a b | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) + EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c) ESum1Inner _ e -> let STArr n (STScal t) = typeOf e - pairty = (STPair (STScal t) (STScal t)) + pairty = STPair (STScal t) (STScal t) in scalTyCase t (ELet ext (dfwdDN e) $ ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) @@ -164,6 +168,9 @@ dfwdDN = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b) + EReshape _ n esh e + | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e) EConst _ t x -> scalTyCase t (EPair ext (EConst ext t x) (EConst ext t 0.0)) (EConst ext t x) @@ -181,16 +188,22 @@ dfwdDN = \case ELet ext (dfwdDN e1) $ ELet ext (weakenExpr WSink (dfwdDN e2)) $ weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) + ERecompute _ e -> dfwdDN e EError _ t s -> EError ext (dn t) s EWith{} -> err_accum EAccum{} -> err_accum + EDeepZero{} -> err_monoid EZero{} -> err_monoid EPlus{} -> err_monoid EOneHot{} -> err_monoid + + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" where err_accum = error "Accumulator operations unsupported in the source program" err_monoid = error "Monoid operations unsupported in the source program" + err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program" deriv_extremum :: ScalIsNumeric t ~ True => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) @@ -214,4 +227,4 @@ dfwdDN = \case (EFst ext (EVar ext tIxN (IS IZ))))))) (ESnd ext (EVar ext t2 (IS IZ))) (zeroScalarConst t)))) - (EMaximum1Inner ext (dfwdDN e)) + (extremum (dfwdDN e)) diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/CHAD/ForwardAD/DualNumbers/Types.hs index fba92d0..5d5dd9e 100644 --- a/src/ForwardAD/DualNumbers/Types.hs +++ b/src/CHAD/ForwardAD/DualNumbers/Types.hs @@ -1,10 +1,10 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -module ForwardAD.DualNumbers.Types where +module CHAD.ForwardAD.DualNumbers.Types where -import AST.Types -import Data +import CHAD.AST.Types +import CHAD.Data -- | Dual-numbers transformation @@ -12,6 +12,7 @@ type family DN t where DN TNil = TNil DN (TPair a b) = TPair (DN a) (DN b) DN (TEither a b) = TEither (DN a) (DN b) + DN (TLEither a b) = TLEither (DN a) (DN b) DN (TMaybe t) = TMaybe (DN t) DN (TArr n t) = TArr n (DN t) DN (TScal t) = DNS t @@ -31,6 +32,7 @@ dn :: STy t -> STy (DN t) dn STNil = STNil dn (STPair a b) = STPair (dn a) (dn b) dn (STEither a b) = STEither (dn a) (dn b) +dn (STLEither a b) = STLEither (dn a) (dn b) dn (STMaybe t) = STMaybe (dn t) dn (STArr n t) = STArr n (dn t) dn (STScal t) = case t of diff --git a/src/CHAD/Interpreter.hs b/src/CHAD/Interpreter.hs new file mode 100644 index 0000000..6410b5b --- /dev/null +++ b/src/CHAD/Interpreter.hs @@ -0,0 +1,468 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Interpreter ( + interpret, + interpretOpen, + Value(..), +) where + +import Control.Monad (foldM, join, when, forM_) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict (runStateT, get, put) +import Data.Bifunctor (bimap) +import Data.Bitraversable (bitraverse) +import Data.Char (isSpace) +import Data.Functor.Identity +import Data.Functor.Product qualified as Product +import Data.Int (Int64) +import Data.IORef +import Data.Tuple (swap) +import System.IO (hPutStrLn, stderr) +import System.IO.Unsafe (unsafePerformIO) + +import Debug.Trace + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Interpreter.Rep + + +newtype AcM s a = AcM { unAcM :: IO a } + deriving newtype (Functor, Applicative, Monad) + +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m + +acmDebugLog :: String -> AcM s () +acmDebugLog s = AcM (hPutStrLn stderr s) + +data V t = V (STy t) (Rep t) + +interpret :: Ex '[] t -> Rep t +interpret = interpretOpen False SNil SNil + +-- | Bool: whether to trace execution with debug prints (very verbose) +interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t +interpretOpen prints env venv e = + runAcM $ + let ?depth = 0 + ?prints = prints + in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e + +interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) + => SList V env -> Ex env t -> AcM s (Rep t) +interpret' env e = do + let tenv = slistMap (\(V t _) -> t) env + let dep = ?depth + let lenlimit = max 20 (100 - dep) + let replace a b = map (\c -> if c == a then b else c) + let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..." + | otherwise = replace '\n' ' ' s + when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e) + res <- let ?depth = dep + 1 in interpret'Rec env e + when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" + return res + +interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t) +interpret'Rec env = \case + EVar _ _ i -> case slistIdx env i of V _ x -> return x + ELet _ a b -> do + x <- interpret' env a + let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b + expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined + EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b + EFst _ e -> fst <$> interpret' env e + ESnd _ e -> snd <$> interpret' env e + ENil _ -> return () + EInl _ _ e -> Left <$> interpret' env e + EInr _ _ e -> Right <$> interpret' env e + ECase _ e a b -> + let STEither t1 t2 = typeOf e + in interpret' env e >>= \case + Left x -> interpret' (V t1 x `SCons` env) a + Right y -> interpret' (V t2 y `SCons` env) b + ENothing _ _ -> return Nothing + EJust _ e -> Just <$> interpret' env e + EMaybe _ a b e -> + let STMaybe t1 = typeOf e + in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e + ELNil _ _ _ -> return Nothing + ELInl _ _ e -> Just . Left <$> interpret' env e + ELInr _ _ e -> Just . Right <$> interpret' env e + ELCase _ e a b c -> + let STLEither t1 t2 = typeOf e + in interpret' env e >>= \case + Nothing -> interpret' env a + Just (Left x) -> interpret' (V t1 x `SCons` env) b + Just (Right y) -> interpret' (V t2 y `SCons` env) c + EConstArr _ _ _ v -> return v + EBuild _ dim a b -> do + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a + arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) + EMap _ a b -> do + let STArr _ t = typeOf b + arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b + EFold1Inner _ _ a b c -> do + let t = typeOf b + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a + x0 <- interpret' env b + arr <- interpret' env c + let sh `ShCons` n = arrayShape arr + arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + ESum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ return $ arrayGenerate sh $ \idx -> sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e) + EReplicate1Inner _ a b -> do + n <- fromIntegral @Int64 @Int <$> interpret' env a + arr <- interpret' env b + let sh = arrayShape arr + return $ arrayGenerate (sh `ShCons` n) (\(idx `IxCons` _) -> arrayIndex arr idx) + EMaximum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ return $ + arrayGenerate sh (\idx -> maximum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EMinimum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ return $ + arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EReshape _ dim esh e -> do + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh + arr <- interpret' env e + return $ arrayReshape sh arr + EZip _ a b -> do + arr1 <- interpret' env a + arr2 <- interpret' env b + let sh = arrayShape arr1 + when (sh /= arrayShape arr2) $ + error "Interpreter: mismatched shapes in EZip" + return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i)) + EFold1InnerD1 _ _ a b c -> do + let t = typeOf b + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a + x0 <- interpret' env b + arr <- interpret' env c + let sh `ShCons` n = arrayShape arr + -- TODO: this is very inefficient, even for an interpreter; with mutable + -- arrays this can be a lot better with no lists + res <- arrayGenerateM sh $ \idx -> do + (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + return (y, arrayFromList (ShNil `ShCons` n) stores) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) + EFold1InnerD2 _ _ ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef + bog <- interpret' env ebog + arrctg <- interpret' env ed + let sh `ShCons` n = arrayShape bog + when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2" + res <- arrayGenerateM sh $ \idx -> do + let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) + loop i !ctg !inpctgs = do + let b = arrayIndex bog (idx `IxCons` i) + (ctg1, ctg2) <- f b ctg + loop (i - 1) ctg1 (ctg2 : inpctgs) + (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] + return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) + EConst _ _ v -> return v + EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e + EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) + EIdx _ a b -> + let STArr n _ = typeOf a + in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) + EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e + EOp _ op e -> interpretOp op <$> interpret' env e + ECustom _ t1 t2 _ pr _ _ e1 e2 -> do + e1' <- interpret' env e1 + e2' <- interpret' env e2 + interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr + ERecompute _ e -> interpret' env e + EWith _ t e1 e2 -> do + initval <- interpret' env e1 + withAccum t (typeOf e2) initval $ \accum -> + interpret' (V (STAccum t) accum `SCons` env) e2 + EAccum _ t p e1 sp e2 e3 -> do + idx <- interpret' env e1 + val <- interpret' env e2 + accum <- interpret' env e3 + accumAddSparseD t p accum idx sp val + EZero _ t ezi -> do + zi <- interpret' env ezi + return $ zeroM t zi + EDeepZero _ t ezi -> do + zi <- interpret' env ezi + return $ deepZeroM t zi + EPlus _ t a b -> do + a' <- interpret' env a + b' <- interpret' env b + return $ addM t a' b' + EOneHot _ t p a b -> do + a' <- interpret' env a + b' <- interpret' env b + return $ onehotM p t a' b' + EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s + +interpretOp :: SOp a t -> Rep a -> Rep t +interpretOp op arg = case op of + OAdd st -> numericIsNum st $ uncurry (+) arg + OMul st -> numericIsNum st $ uncurry (*) arg + ONeg st -> numericIsNum st $ negate arg + OLt st -> numericIsNum st $ uncurry (<) arg + OLe st -> numericIsNum st $ uncurry (<=) arg + OEq st -> styIsEq st $ uncurry (==) arg + ONot -> not arg + OAnd -> uncurry (&&) arg + OOr -> uncurry (||) arg + OIf -> if arg then Left () else Right () + ORound64 -> round arg + OToFl64 -> fromIntegral arg + ORecip st -> floatingIsFractional st $ recip arg + OExp st -> floatingIsFractional st $ exp arg + OLog st -> floatingIsFractional st $ log arg + OIDiv st -> integralIsIntegral st $ uncurry quot arg + OMod st -> integralIsIntegral st $ uncurry rem arg + where + styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r + styIsEq STI32 = id + styIsEq STI64 = id + styIsEq STF32 = id + styIsEq STF64 = id + styIsEq STBool = id + +zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t +zeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi)) + SMTLEither _ _ -> Nothing + SMTMaybe _ -> Nothing + SMTArr _ t -> arrayMap (zeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 + STF32 -> 0.0 + STF64 -> 0.0 + +deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t +deepZeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi)) + SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi + SMTMaybe t -> fmap (deepZeroM t) zi + SMTArr _ t -> arrayMap (deepZeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 + STF32 -> 0.0 + STF64 -> 0.0 + +addM :: SMTy t -> Rep t -> Rep t -> Rep t +addM typ a b = case typ of + SMTNil -> () + SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b)) + SMTLEither t1 t2 -> case (a, b) of + (Nothing, _) -> b + (_, Nothing) -> a + (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y)) + (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y)) + _ -> error "Plus of inconsistent LEithers" + SMTMaybe t -> case (a, b) of + (Nothing, _) -> b + (_, Nothing) -> a + (Just x, Just y) -> Just (addM t x y) + SMTArr _ t -> + let sh1 = arrayShape a + sh2 = arrayShape b + in if | shapeSize sh1 == 0 -> b + | shapeSize sh2 == 0 -> a + | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i)) + | otherwise -> error "Plus of inconsistently shaped arrays" + SMTScal sty -> numericIsNum sty $ a + b + +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a +onehotM SAPHere _ _ val = val +onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) +onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) +onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val)) +onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val)) +onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val) +onehotM (SAPArrIdx prj) (SMTArr n a) idx val = + runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx + +withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) +withAccum t _ initval f = AcM $ do + accum <- newAcDense t initval + out <- unAcM $ f accum + val <- readAc t accum + return (out, val) + +newAcDense :: SMTy a -> Rep a -> IO (RepAc a) +newAcDense typ val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val + SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val + SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val + SMTArr _ t1 -> arrayMapM (newAcDense t1) val + SMTScal _ -> newIORef val + +onehotArray :: Monad m + => (Rep (AcIdxS p a) -> m v) -- ^ the "one" + -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" + -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v) +onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = + let arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = arrayShape ziarr + !linindex = toLinearIndex arrsh arrindex + in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) + +readAc :: SMTy t -> RepAc t -> IO (Rep t) +readAc typ val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val + SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val + SMTMaybe t -> traverse (readAc t) =<< readIORef val + SMTArr _ t -> traverse (readAc t) val + SMTScal _ -> readIORef val + +accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s () +accumAddSparseD typ prj ref idx sp val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref sp val + + (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val + + (SMTLEither t1 _, SAPLeft prj') -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val + Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") + (SMTLEither _ t2, SAPRight prj') -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val + Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") + + (SMTMaybe t1, SAPJust prj') -> + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)") + (\ac -> accumAddSparseD t1 prj' ac idx sp val) + + (SMTArr n t1, SAPArrIdx prj') -> + let (arrindex', idx') = idx + arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = arrayShape ref + linindex = toLinearIndex arrsh arrindex + in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val + +accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s () +accumAddDense typ ref sp val = case (typ, sp) of + (_, _) | isAbsent sp -> return () + (_, SpAbsent) -> return () + (_, SpSparse s) -> + case val of + Nothing -> return () + Just val' -> accumAddDense typ ref s val' + (SMTPair t1 t2, SpPair s1 s2) -> do + accumAddDense t1 (fst ref) s1 (fst val) + accumAddDense t2 (snd ref) s2 (snd val) + (SMTLEither t1 t2, SpLEither s1 s2) -> + case val of + Nothing -> return () + Just (Left val1) -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddDense t1 ac1 s1 val1 + Right{} -> error "Mismatched Either in accumAddSparse (r +l)") + Just (Right val2) -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddDense t2 ac2 s2 val2 + Left{} -> error "Mismatched Either in accumAddSparse (l +r)") + (SMTMaybe t, SpMaybe s) -> + case val of + Nothing -> return () + Just val' -> + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)") + (\ac -> accumAddDense t ac s val') + (SMTArr _ t1, SpArr s) -> + forM_ [0 .. arraySize ref - 1] $ \i -> + accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i) + (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) + +-- TODO: makeval is always 'error' now. Simplify? +realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () +realiseMaybeSparse ref makeval modifyval = + -- Try modifying what's already in ref. The 'join' makes the snd + -- of the function's return value a _continuation_ that is run after + -- the critical section ends. + AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of + -- Oops, ref's contents was still sparse. Have to initialise + -- it first, then try again. + Nothing -> (ac, do val <- makeval + join $ atomicModifyIORef' ref $ \ac' -> case ac' of + Nothing -> (Just val, return ()) + Just val' -> (ac', unAcM $ modifyval val')) + -- Yep, ref already had a value in there, so we can just add + -- val' to it recursively. + Just val -> (ac, unAcM $ modifyval val) + + +numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r +numericIsNum STI32 = id +numericIsNum STI64 = id +numericIsNum STF32 = id +numericIsNum STF64 = id + +floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsFloating st ~ True) => r) -> r +floatingIsFractional STF32 = id +floatingIsFractional STF64 = id + +integralIsIntegral :: ScalIsIntegral st ~ True => SScalTy st -> ((Integral (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsIntegral st ~ True) => r) -> r +integralIsIntegral STI32 = id +integralIsIntegral STI64 = id + +unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) + -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n +unTupRepIdx nil _ SZ _ = nil +unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i + +tupRepIdx :: (forall m. f (S m) -> (f m, Int)) + -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) +tupRepIdx _ SZ _ = () +tupRepIdx uncons (SS n) tup = + let (tup', i) = uncons tup + in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i + +ixUncons :: Index (S n) -> (Index n, Int) +ixUncons (IxCons idx i) = (idx, i) + +shUncons :: Shape (S n) -> (Shape n, Int) +shUncons (ShCons idx i) = (idx, i) + +mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b) +mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f' + where f' x = do + s <- get + (s', y) <- lift $ f s x + put s' + return y diff --git a/src/Interpreter/Accum.hs b/src/CHAD/Interpreter/Accum.hs index af7be1e..8e5c040 100644 --- a/src/Interpreter/Accum.hs +++ b/src/CHAD/Interpreter/Accum.hs @@ -12,7 +12,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum ( +module CHAD.Interpreter.Accum ( AcM, runAcM, Rep', @@ -35,9 +35,9 @@ import GHC.IO (IO(..)) import GHC.Word import System.IO.Unsafe (unsafePerformIO) -import Array -import AST -import Data +import CHAD.Array +import CHAD.AST +import CHAD.Data newtype AcM s a = AcM (IO a) diff --git a/src/Interpreter/AccumOld.hs b/src/CHAD/Interpreter/AccumOld.hs index af7be1e..8e5c040 100644 --- a/src/Interpreter/AccumOld.hs +++ b/src/CHAD/Interpreter/AccumOld.hs @@ -12,7 +12,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum ( +module CHAD.Interpreter.Accum ( AcM, runAcM, Rep', @@ -35,9 +35,9 @@ import GHC.IO (IO(..)) import GHC.Word import System.IO.Unsafe (unsafePerformIO) -import Array -import AST -import Data +import CHAD.Array +import CHAD.AST +import CHAD.Data newtype AcM s a = AcM (IO a) diff --git a/src/Interpreter/Rep.hs b/src/CHAD/Interpreter/Rep.hs index be2a4cc..fadc6be 100644 --- a/src/Interpreter/Rep.hs +++ b/src/CHAD/Interpreter/Rep.hs @@ -1,44 +1,41 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -module Interpreter.Rep where +module CHAD.Interpreter.Rep where +import Control.DeepSeq +import Data.Coerce (coerce) import Data.List (intersperse, intercalate) import Data.Foldable (toList) import Data.IORef -import GHC.TypeError +import GHC.Exts (withDict) -import Array -import AST -import AST.Pretty -import Data +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.Data type family Rep t where Rep TNil = () Rep (TPair a b) = (Rep a, Rep b) Rep (TEither a b) = Either (Rep a) (Rep b) + Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b)) Rep (TMaybe t) = Maybe (Rep t) Rep (TArr n t) = Array n (Rep t) Rep (TScal sty) = ScalRep sty Rep (TAccum t) = RepAc t --- Mutable, represents D2 of t. Has an O(1) zero. +-- Mutable, represents monoid types t. type family RepAc t where RepAc TNil = () - RepAc (TPair a b) = IORef (Maybe (RepAc a, RepAc b)) - RepAc (TEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b))) + RepAc (TPair a b) = (RepAc a, RepAc b) + RepAc (TLEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b))) RepAc (TMaybe t) = IORef (Maybe (RepAc t)) - RepAc (TArr n t) = IORef (Maybe (Array n (RepAc t))) - RepAc (TScal sty) = RepAcScal sty - RepAc (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators") - -type family RepAcScal t where - RepAcScal TI32 = () - RepAcScal TI64 = () - RepAcScal TF32 = IORef Float - RepAcScal TF64 = IORef Double - RepAcScal TBool = () + RepAc (TArr n t) = Array n (RepAc t) + RepAc (TScal sty) = IORef (ScalRep sty) newtype Value t = Value { unValue :: Rep t } @@ -57,8 +54,11 @@ vUnpair (Value (x, y)) = (Value x, Value y) showValue :: Int -> STy t -> Rep t -> ShowS showValue _ STNil () = showString "()" showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" -showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x -showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y +showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x +showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y +showValue _ (STLEither _ _) Nothing = showString "LNil" +showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x +showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y showValue _ (STMaybe _) Nothing = showString "Nothing" showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x showValue d (STArr _ t) arr = showParen (d > 10) $ @@ -66,13 +66,13 @@ showValue d (STArr _ t) arr = showParen (d > 10) $ . showString " [" . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) . showString "]" -showValue _ (STScal sty) x = case sty of - STF32 -> shows x - STF64 -> shows x - STI32 -> shows x - STI64 -> shows x - STBool -> shows x -showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSTy 0 t ++ ">" +showValue d (STScal sty) x = case sty of + STF32 -> showsPrec d x + STF64 -> showsPrec d x + STI32 -> showsPrec d x + STI64 -> showsPrec d x + STBool -> showsPrec d x +showValue _ (STAccum t) _ = showString $ "<accumulator for " ++ ppSMTy 0 t ++ ">" showEnv :: SList STy env -> SList Value env -> String showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" @@ -80,3 +80,26 @@ showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" showEntries :: SList STy env -> SList Value env -> [String] showEntries SNil SNil = [] showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs + +rnfRep :: STy t -> Rep t -> () +rnfRep STNil () = () +rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y +rnfRep (STEither a _) (Left x) = rnfRep a x +rnfRep (STEither _ b) (Right y) = rnfRep b y +rnfRep (STLEither _ _) Nothing = () +rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x +rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y +rnfRep (STMaybe _) Nothing = () +rnfRep (STMaybe t) (Just x) = rnfRep t x +rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr = + withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr) +rnfRep (STScal t) x = case t of + STI32 -> rnf x + STI64 -> rnf x + STF32 -> rnf x + STF64 -> rnf x + STBool -> rnf x +rnfRep STAccum{} _ = error "Cannot rnf accumulators" + +instance KnownTy t => NFData (Value t) where + rnf (Value x) = rnfRep (knownTy @t) x diff --git a/src/CHAD/Language.hs b/src/CHAD/Language.hs new file mode 100644 index 0000000..6621eef --- /dev/null +++ b/src/CHAD/Language.hs @@ -0,0 +1,423 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} +module CHAD.Language ( + -- * Named expressions + fromNamed, + NExpr, NFun, + + -- * Functions + lambda, + body, + inline, + (.$), + + -- * Basic language constructs + let_, + pair, fst_, snd_, nil, + inl, inr, case_, + nothing, just, maybe_, + + -- * Array operations + constArr_, + build1, build2, build, + map_, + fold1i, fold1i', + sum1i, + unit, + replicate1i, + maximum1i, minimum1i, + reshape, + fold1iD1, fold1iD1', + fold1iD2, + + -- * Scalar operations + -- | Note that 'NExpr' is also an instance of some numeric classes like 'Num' and 'Floating'. + const_, + idx0, + (!), + shape, + length_, + error_, + (.==), (.<), (CHAD.Language..>), (.<=), (.>=), + not_, and_, or_, + mod_, round_, toFloat_, idiv, + + -- * Control flow + if_, + + -- * Special operations + custom, + recompute, + with, accum, accumS, + oper, oper2, + + -- * Helper types + (:->)(..), + + -- * Reexports + TIx, + Lookup, + Ex, + Ty(..), + SNat(..), Nat(..), N0, N1, N2, N3, +) where + +import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.Language.AST + + +-- | Helper type, used for e.g. 'case_' and 'build'. +data a :-> b = a :-> b + deriving (Show) +infixr 0 :-> + + +-- | See 'fromNamed' for a usage example. +body :: NExpr env t -> NFun env env t +body = NBody + +-- | See 'fromNamed' for a usage example. +lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t +lambda = NLam + +-- | Inline a function here, with the given list of expressions as arguments. +-- While this is a normal 'SList', the @params@ list is reversed from the +-- natural argument order of the function; the '(.$)' helper operator serves to +-- "fix" the order. +-- +-- @ +-- let fun = 'lambda' \@(TScal TF64) #x $ 'lambda' \@(TScal TBool) #b $ 'body' $ if_ #b #x (#x + 1) +-- in 'inline' fun ('SNil' .$ 16 .$ 'const_' True) +-- @ +-- +-- Note that no 'const_' is needed for the @16@, because 'NExpr' implements +-- 'Num'. +inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t +inline = inlineNFun + +-- | Helper for constructing the argument list for 'inline'; +-- @(.$) = flip 'SCons'@. See 'inline'. +(.$) :: SList f list -> f a -> SList f (a : list) +(.$) = flip SCons + + +-- | The first 'Var' argument is the left-hand side of this let-binding. For example: +-- +-- @ +-- 'fromNamed' $ 'lambda' \@(TScal TI64) #a $ 'body' $ +-- 'let_' #x (#a + 1) $ +-- #x * #a +-- @ +-- +-- This produces an expression of type @'Ex' '[TScal TI64] (TScal TI64)@ that +-- corresponds to the Haskell code @\\a -> let x = a + 1 in x * a@. +let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t +let_ = NELet + +pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) +pair = NEPair + +fst_ :: NExpr env (TPair a b) -> NExpr env a +fst_ = NEFst + +snd_ :: NExpr env (TPair a b) -> NExpr env b +snd_ = NESnd + +nil :: NExpr env TNil +nil = NENil + +inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b) +inl = NEInl knownTy + +inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b) +inr = NEInr knownTy + +-- | A @case@ expression on @Either@s. For example, the following expression +-- will evaluate to 10 + 1 = 11: +-- +-- @ +-- 'case_' ('inl' 10) +-- (#x :-> #x + 1) +-- (#y :-> #y * 2) +-- @ +case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c +case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 + +nothing :: KnownTy a => NExpr env (TMaybe a) +nothing = NENothing knownTy + +just :: NExpr env a -> NExpr env (TMaybe a) +just = NEJust + +-- | Analogue of the 'Prelude.maybe' function in the Haskell Prelude: +-- +-- @ +-- 'maybe_' 2 (#x :-> #x * 3) (...) +-- @ +-- +-- will return 2 if @(...)@ is @Nothing@ and @x + 3@ if it is @Just x@. +maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b +maybe_ a (v :-> b) c = NEMaybe a v b c + +-- | To construct 'Array' values, see "CHAD.Array". +constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) +constArr_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConstArr knownNat ty x + +-- | Special case of 'build' for 1-dimensional arrays. This produces the array +-- [0.0, 1.0, 2.0]: +-- +-- @ +-- 'build1' 3 (#i :-> 'toFloat_' #i) +-- @ +build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) +build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b)) + +-- | Special case of 'build' for 2-dimensional arrays. +build2 :: NExpr env TIx -> NExpr env TIx + -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t) + -> NExpr env (TArr (S (S Z)) t) +build2 a1 a2 (v1 :-> v2 :-> b) = + NEBuild (SS (SS SZ)) + (pair (pair nil a1) a2) + #idx + (let_ v1 (snd_ (fst_ #idx)) $ + let_ v2 (NEDrop SZ (snd_ #idx)) $ + NEDrop (SS (SS SZ)) b) + +-- | General n-dimensional elementwise array constructor. A 3-dimensional index +-- looks like @((((), i1), i2), i3)@; other dimensionalities are analogous. The +-- innermost dimension (i.e. whose index variable varies the fastest in the +-- standard memory layout) is the right-most index, i.e. @i3@ in 3D example. To +-- create a 10-by-10 table of (row, column) pairs: +-- +-- @ +-- 'build' ('SS' ('SS' 'SZ')) ('pair' ('pair' 'nil' 10) 10) (#i :-> #j :-> 'pair' #i #j) +-- @ +build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) +build n a (v :-> b) = NEBuild n a v b + +map_ :: forall n a b env name. (KnownNat n, KnownTy a) + => (Var name a :-> NExpr ('(name, a) : env) b) + -> NExpr env (TArr n a) -> NExpr env (TArr n b) +map_ (v :-> a) b = NEMap v a b + +-- | Fold over the innermost dimension of an array, thus reducing its dimensionality by one. +fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t t) + in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +-- | The underlying AST constructor for a fold takes a function with /one/ +-- argument: a pair of inputs. 'fold1i'' directly returns this AST constructor +-- in case it is helpful for testing. The 'fold1i' function is a convenience +-- wrapper around 'fold1i''. +fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3 + +sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +sum1i e = NESum1Inner e + +unit :: NExpr env t -> NExpr env (TArr Z t) +unit = NEUnit + +replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t)) +replicate1i n a = NEReplicate1Inner n a + +maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +maximum1i e = NEMaximum1Inner e + +minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +minimum1i e = NEMinimum1Inner e + +reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) +reshape = NEReshape + +-- | 'fold1iD1'' with a curried combination function. +fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t1 t1) + in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +-- | Primal of a fold. Not supported in the input program for reverse differentiation. +fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3 + +-- | Reverse pass of a fold. Not supported in the input program for reverse differentiation. +fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) + -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) +fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3 + +const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) +const_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty x + +idx0 :: NExpr env (TArr Z t) -> NExpr env t +idx0 = NEIdx0 + +-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) +-- (.!) = NEIdx1 +-- infixl 9 .! + +-- | Index an array. Note that the index is a tuple, just like the argument to +-- the function in 'build'. To index a 2-dimensional array @a@ at row @i@ and +-- column @j@, write @a '!' 'pair' ('pair' 'nil' i) j@. +(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t +(!) = NEIdx +infixl 9 ! + +shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) +shape = NEShape + +-- | Convenience special case of 'shape' for single-dimensional arrays. +length_ :: NExpr env (TArr N1 t) -> NExpr env TIx +length_ e = snd_ (shape e) + +oper :: SOp a t -> NExpr env a -> NExpr env t +oper = NEOp + +oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t +oper2 op a b = NEOp op (pair a b) + +error_ :: KnownTy t => String -> NExpr env t +error_ s = NEError knownTy s + +-- | Specify a custom reverse derivative for a subexpression. Morally, the type +-- of this combinator should be read as follows: +-- +-- @ +-- custom :: (a -> b -> t) -- normal semantics +-- -> (D1 a -> D1 b -> (D1 t, tape)) -- forward pass +-- -> (tape -> D2 t -> D2 b) -- reverse pass +-- -> a -> b -- arguments +-- -> t -- result +-- @ +-- +-- In normal evaluation, or when forward-differentiating, the first argument is +-- taken and the second and third are ignored. When reverse-differentiating +-- using CHAD, however, the /first/ argument is ignored and the second and +-- third arguments are respectively put in the forward and the reverse passes +-- of the derivative program. The @tape@ value may be used to remember primals +-- for the reverse pass. +-- +-- This combinator allows for "inactive" and "active" inputs to the operation; +-- derivatives to the "inactive" input are not propagated. The active input +-- (whose derivatives /are/ propagated) has type @b@; the inactive input has +-- type @a@. +-- +-- No accumulators are allowed inside @a@, @b@ and @tape@. +custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) + -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)) + -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)) + -> NExpr env a -> NExpr env b + -> NExpr env t +custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = + NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 + +-- | Semantically the identity, but when reverse differentiating using CHAD, +-- the contained expression is recomputed in the reverse pass. This is a +-- light-weight form of checkpointing, with the goal of reducing the number +-- primal values being stored and thus reducing memory use and memory traffic. +-- +-- Note that free variables of the contained expression do still need to be +-- stored, as we do need to be able to recompute the expression in the reverse +-- pass. +recompute :: NExpr env a -> NExpr env a +recompute = NERecompute + +-- | Introduce an accumulator. The initial value is not allowed to be sparse! +-- See 'CHAD.AST.EWith'. Not supported in the input program for reverse +-- differentiation. +with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) +with a (n :-> b) = NEWith (knownMTy @t) a n b + +-- | Accumulate to an accumulator. Not supported in the input program for +-- reverse differentiation. +accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil +accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c + +-- | Accumulate to an accumulator with additional sparsity. Not supported in +-- the input program for reverse differentiation. +accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil +accumS p a sp b c = NEAccum knownMTy p a sp b c + + +(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .== b = oper (OEq knownScalTy) (pair a b) +infix 4 .== + +(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .< b = oper (OLt knownScalTy) (pair a b) +infix 4 .< + +(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>) = flip (.<) +infix 4 .> + +(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .<= b = oper (OLe knownScalTy) (pair a b) +infix 4 .<= + +(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>=) = flip (.<=) +infix 4 .>= + +not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) +not_ = oper ONot + +and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +and_ = oper2 OAnd +infixr 3 `and_` + +or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +or_ = oper2 OOr +infixr 2 `or_` + +mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) +mod_ = oper2 (OMod knownScalTy) +infixl 7 `mod_` + +-- | The first alternative is the True case; the second is the False case. +if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t +if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) + +round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64) +round_ = oper ORound64 + +toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64) +toFloat_ = oper OToFl64 + +idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t) +idiv = oper2 (OIDiv knownScalTy) +infixl 7 `idiv` diff --git a/src/Language/AST.hs b/src/CHAD/Language/AST.hs index 84544f8..502a2b3 100644 --- a/src/Language/AST.hs +++ b/src/CHAD/Language/AST.hs @@ -4,7 +4,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -12,19 +14,22 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -module Language.AST where +module CHAD.Language.AST where import Data.Kind (Type) import Data.Type.Equality import GHC.OverloadedLabels -import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..)) +import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) -import Array -import AST -import CHAD.Types -import Data +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types +-- | A named expression: variables have names, not De Bruijn indices. +-- Otherwise essentially identical to 'Expr'. type NExpr :: [(Symbol, Ty)] -> Ty -> Type data NExpr env t where -- lambda calculus @@ -49,12 +54,23 @@ data NExpr env t where -- array operations NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) - NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t) + NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEUnit :: NExpr env t -> NExpr env (TArr Z t) NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) + + NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) + -> NExpr env t1 + -> NExpr env (TArr (S n) t1) + -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) + NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2) + -> NExpr env (TArr (S n) b) + -> NExpr env (TArr n t2) + -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) -- expression operations NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) @@ -71,9 +87,12 @@ data NExpr env t where -> NExpr env a -> NExpr env b -> NExpr env t + -- fake halfway checkpointing + NERecompute :: NExpr env t -> NExpr env t + -- accumulation effect on monoids - NEWith :: STy t -> NExpr env (D2 t) -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a (D2 t)) - NEAccum :: STy t -> SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env (D2 a) -> NExpr env (TAccum t) -> NExpr env TNil + NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil -- partiality NEError :: STy a -> String -> NExpr env a @@ -82,11 +101,23 @@ data NExpr env t where NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t deriving instance Show (NExpr env t) -type family Lookup name env where - Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'") - Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") - Lookup name ('(name, t) : env) = t - Lookup name (_ : env) = Lookup name env +-- | Look up the type of a name in a named environment. +type Lookup name env = Lookup1 (name == "_") name env +-- | This curious stack of type families is used instead of normal pattern +-- matching so the decidable boolean predicate "==" is used. This means that +-- introducing evidence of @(name1 == name2) ~ False@ may allow a certain +-- lookup to reduce even if the names in question are not statically known. +-- This flexibility is used with e.g. 'assertSymbolDistinct' in +-- 'CHAD.Language.fold1i'. +type family Lookup1 eqblank name env where + Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup1 False name env = Lookup2 name env +type family Lookup2 name env where + Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") + Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env +type family Lookup3 eq t name env where + Lookup3 True t _ _ = t + Lookup3 False _ name env = Lookup2 name env type family DropNth i env where DropNth Z (_ : env) = env @@ -138,10 +169,20 @@ data NEnv env where NTop :: NEnv '[] NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env) --- | First (outermost) parameter on the outside, on the left. --- * env: environment of this function (grows as you go deeper inside lambdas) --- * env': environment of the body of the function --- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list +-- | A named /function/. These can be used in only two ways: they can be +-- converted to an unnamed 'Expr' using 'fromNamed', and they can be inlined +-- using 'CHAD.Language.inline'. +-- +-- * @env@: environment of this function (smaller than @env'@; grows as you descend under lambdas) +-- * @env'@: environment of the body of the function +-- +-- For example, a function @(\\(x :: a) (y :: b) -> _ :: c)@ with two free +-- variables, @u :: t1@ and @v :: t2@, would be represented with a value of the +-- following type: +-- +-- @ +-- NFun '['("v", t2), '("u", t1)] '['("y", b), '("x", a), '("v", t2), '("u", t1)] c +-- @ data NFun env env' t where NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t NBody :: NExpr env' t -> NFun env' env' t @@ -157,6 +198,41 @@ envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t inlineNFun fun args = NEUnnamed (fromNamed fun) args +-- | Convert a named function to an unnamed expression with free variables, +-- ready for consumption by the rest of this library. The function must be +-- closed (meaning that the function as a whole cannot have free variables), +-- and the arguments of the function are realised as free variables of the +-- resulting expression. Typical usage looks as follows: +-- +-- @ +-- {-# LANGUAGE OverloadedLabels #-} +-- import CHAD.Language +-- 'fromNamed' $ 'CHAD.Language.lambda' \@(TScal TF64) #x $ 'CHAD.Language.lambda' \@(TScal TI64) #i $ 'CHAD.Language.body' $ #x + 'CHAD.Language.toFloat_' #i +-- :: 'Ex' '[TScal TI64, TScal TF64] (TScal TF64) +-- @ +-- +-- The rest of the library generally considers expressions with free variables +-- to stand in for "functions", by considering the free variables as the +-- function's inputs. +-- +-- Note that while environments normally grow to the right (e.g. in type theory +-- notation), as they as type-level lists here, they grow to the /left/. This +-- is why the second (innermost) argument of the example, @i@, ends up at the +-- head of the environment of the constructed expression. +-- +-- __Type applications__: The type applications to 'CHAD.Language.lambda' above +-- are good practice, but not always necessary; if GHC can infer the type of +-- the argument from the body of the expression, the type application is +-- unnecessary. +-- +-- __Variables__: The major element of syntactic sugar in this module is using +-- OverloadedLabels for variable names. Variables are represented in 'NExpr' +-- (and thus 'NFun') using the 'Var' type; you should never have to manually +-- construct a 'Var'. Instead, 'Var' implements 'IsLabel' and as such can be +-- produced with the syntax @#name@, where "name" is the name of the variable. +-- This syntax produces a polymorphic variable reference whose (embedded) type +-- is left to GHC's type inference engine using a 'KnownTy' constraint. See +-- also 'CHAD.Language.let_'. fromNamed :: NFun '[] env t -> Ex (UnName env) t fromNamed = fromNamedFun NTop @@ -195,12 +271,17 @@ fromNamedExpr val = \case NEConstArr n t x -> EConstArr ext n t x NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEMap n a b -> EMap ext (lambda val n a) (go b) + NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c) NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) NEMaximum1Inner e -> EMaximum1Inner ext (go e) NEMinimum1Inner e -> EMinimum1Inner ext (go e) + NEReshape n a b -> EReshape ext n (go a) (go b) + + NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c) + NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) NEConst t x -> EConst ext t x NEIdx0 e -> EIdx0 ext (go e) @@ -215,9 +296,10 @@ fromNamedExpr val = \case (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) (go e1) (go e2) + NERecompute e -> ERecompute ext (go e) NEWith t a n b -> EWith ext t (go a) (lambda val n b) - NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c) + NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c) NEError t s -> EError ext t s @@ -256,3 +338,17 @@ dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env dropNthW SZ (_ `NPush` _) = WSink dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) dropNthW _ NTop = error "DropNth: index out of range" + +assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r +assertSymbolNotUnderscore s@SSymbol k = + case symbolVal s of + "_" -> error "assertSymbolNotUnderscore: was underscore" + _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k + +assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r +assertSymbolDistinct s1@SSymbol s2@SSymbol k + | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")" + | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k + +equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r +equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k diff --git a/src/Lemmas.hs b/src/CHAD/Lemmas.hs index 31a43ed..55ef042 100644 --- a/src/Lemmas.hs +++ b/src/CHAD/Lemmas.hs @@ -4,7 +4,7 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE AllowAmbiguousTypes #-} -module Lemmas (module Lemmas, (:~:)(Refl)) where +module CHAD.Lemmas (module CHAD.Lemmas, (:~:)(Refl)) where import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) diff --git a/src/CHAD/Simplify.hs b/src/CHAD/Simplify.hs new file mode 100644 index 0000000..ea253d6 --- /dev/null +++ b/src/CHAD/Simplify.hs @@ -0,0 +1,620 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Simplify ( + simplifyN, simplifyFix, + SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, +) where + +import Control.Monad (ap) +import Data.Bifunctor (first) +import Data.Function (fix) +import Data.Monoid (Any(..)) + +import Debug.Trace + +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Pretty +import CHAD.AST.Sparse.Types +import CHAD.AST.UnMonoid (acPrjCompose) +import CHAD.Data +import CHAD.Simplify.TH + + +data SimplifyConfig = SimplifyConfig + { scLogging :: Bool + } + +defaultSimplifyConfig :: SimplifyConfig +defaultSimplifyConfig = SimplifyConfig False + +simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t +simplifyN 0 = id +simplifyN n = simplifyN (n - 1) . simplify + +simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t +simplify = + let ?accumInScope = checkAccumInScope @env knownEnv + ?config = defaultSimplifyConfig + in snd . runSM . simplify' + +simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t +simplifyWith config = + let ?accumInScope = checkAccumInScope @env knownEnv + ?config = config + in snd . runSM . simplify' + +simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t +simplifyFix = simplifyFixWith defaultSimplifyConfig + +simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t +simplifyFixWith config = + let ?accumInScope = checkAccumInScope @env knownEnv + ?config = config + in fix $ \loop e -> + let (act, e') = runSM (simplify' e) + in if act then loop e' else e' + +-- | simplify monad +newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a)) + deriving (Functor) + +instance Applicative (SM tenv tt env t) where + pure x = SM (\_ -> (Any False, x)) + (<*>) = ap + +instance Monad (SM tenv tt env t) where + SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx + +runSM :: SM env t env t a -> (Bool, a) +runSM (SM f) = first getAny (f id) + +smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) +smReconstruct core = SM (\ctx -> (Any False, ctx core)) + +class Monad m => ActedMonad m where + tellActed :: m () + hideActed :: m a -> m a + liftActed :: (Any, a) -> m a + +instance ActedMonad ((,) Any) where + tellActed = (Any True, ()) + hideActed (_, x) = (Any False, x) + liftActed = id + +instance ActedMonad (SM tenv tt env t) where + tellActed = SM (\_ -> tellActed) + hideActed (SM f) = SM (\ctx -> hideActed (f ctx)) + liftActed pair = SM (\_ -> pair) + +-- more convenient in practice +acted :: ActedMonad m => m a -> m a +acted m = tellActed >> m + +within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a +within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) + +simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) +simplify' expr + | scLogging ?config = do + res <- simplify'Rec expr + full <- smReconstruct res + let printed = ppExpr knownEnv full + replace a bs = concatMap (\x -> if x == a then bs else [x]) + str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed + | otherwise = "--- simplify step: " ++ printed + traceM str + return res + | otherwise = simplify'Rec expr + +simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) +simplify'Rec = \case + -- inlining + ELet _ rhs body + | cheapExpr rhs + -> acted $ simplify' (substInline rhs body) + + | Occ lexOcc runOcc <- occCount IZ body + , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply + || (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not + -> acted $ simplify' (substInline rhs body) + + -- let splitting / let peeling + ELet _ (EPair _ a b) body -> + acted $ simplify' $ + ELet ext a $ + ELet ext (weakenExpr WSink b) $ + subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ) + IS i -> EVar ext t (IS (IS i))) + body + ELet _ (EJust _ a) body -> + acted $ simplify' $ ELet ext a $ subst0 (EJust ext (EVar ext (typeOf a) IZ)) body + ELet _ (EInl _ t2 a) body -> + acted $ simplify' $ ELet ext a $ subst0 (EInl ext t2 (EVar ext (typeOf a) IZ)) body + ELet _ (EInr _ t1 a) body -> + acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body + + -- let rotation + ELet _ (ELet _ rhs a) b -> do + b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b + acted $ simplify' $ + ELet ext rhs $ + ELet ext a $ + weakenExpr (WCopy WSink) b' + + -- beta rules for products + EFst _ (EPair _ e e') + | not (hasAdds e') -> acted $ simplify' e + | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) + ESnd _ (EPair _ e' e) + | not (hasAdds e') -> acted $ simplify' e + | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) + + -- beta rules for coproducts + ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs) + ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs) + + -- beta rules for maybe + EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1 + EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1 + + -- let floating + EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body)) + ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body)) + ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) + EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) + EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) + EAccum _ t p e1 sp (ELet _ rhs body) acc -> + acted $ simplify' $ + ELet ext rhs $ + EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc) + + -- let () = e in () ~> e + ELet _ e1 (ENil _) | STNil <- typeOf e1 -> + acted $ simplify' e1 + + -- map (\_ -> x) e ~> build (shape e) (\_ -> x) + EMap _ e1 e2 + | Occ Zero Zero <- occCount IZ e1 + , STArr n _ <- typeOf e2 -> + acted $ simplify' $ + EBuild ext n (EShape ext e2) $ + subst (\_ t' -> \case IZ -> error "Unused variable was used" + IS i -> EVar ext t' (IS i)) + e1 + + -- vertical fusion + EMap _ e1 (EMap _ e2 e3) -> + acted $ simplify' $ + EMap ext (ELet ext e2 (weakenExpr (WCopy WSink) e1)) e3 + + -- projection down-commuting + EFst _ (ECase _ e1 e2 e3) -> + acted $ simplify' $ + ECase ext e1 (EFst ext e2) (EFst ext e3) + ESnd _ (ECase _ e1 e2 e3) -> + acted $ simplify' $ + ECase ext e1 (ESnd ext e2) (ESnd ext e3) + EFst _ (EMaybe _ e1 e2 e3) -> + acted $ simplify' $ + EMaybe ext (EFst ext e1) (EFst ext e2) e3 + ESnd _ (EMaybe _ e1 e2 e3) -> + acted $ simplify' $ + EMaybe ext (ESnd ext e1) (ESnd ext e2) e3 + + -- TODO: more array indexing + EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2 + EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1 + EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) + EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1 + + -- TODO: more array shape + EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1 + EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2) + EShape _ (EReplicate1Inner _ en earr) -> acted $ simplify' (EPair ext (EShape ext earr) en) + + -- TODO: more constant folding + EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) + EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext)) + + -- inline cheap array constructors + ELet _ (EReplicate1Inner _ e1 e2) e3 -> + acted $ simplify' $ + ELet ext (EPair ext e1 e2) $ + let v = EVar ext (STPair tIx (typeOf e2)) IZ + in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3 + -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is + -- -- cheap, which it can't be because (!) is not cheap if you do AD after. + -- -- Should do proper SoA representation. + -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 -> + -- acted $ simplify' $ + -- ELet ext e1 $ + -- subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3 + + -- eta rule for unit + e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) -> + case e of + ENil _ -> return e + _ -> acted $ return (ENil ext) + + EBuild _ SZ _ e -> + acted $ simplify' $ EUnit ext (substInline (ENil ext) e) + + -- monoid rules + EAccum _ t p e1 sp e2 acc -> do + e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1 + e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2 + acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc + simplifyOHT (OneHotTerm SAID t p e1' sp e2') + (acted $ return (ENil ext)) + (\sp' (InContext w wrap e) -> do + e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e + return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc'))) + (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do + -- The acted management here is a hideous mess. + e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2'' + return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc'))) + EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e + EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e + EOneHot _ t p e1 e2 -> do + e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 + e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 + simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2') + (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) + (\sp' (InContext _ wrap e) -> + case isDense t sp' of + Just Refl -> do + e' <- hideActed $ within wrap $ simplify' e + return (wrap e') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") + (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> + case isDense (acPrjTy p' t') sp' of + Just Refl -> do + e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2'' + return (wrap $ EOneHot ext t' p' e1''' e2''') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") + + -- type-specific equations for plus + EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> + acted $ return (ENil ext) + + EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) -> + acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2) + + EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) -> + acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2) + EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) -> + acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2) + EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e + EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e + + EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) -> + acted $ simplify' $ EJust ext (EPlus ext t e1 e2) + EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e + EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e + + -- fallback recursion + EVar _ t i -> pure $ EVar ext t i + ELet _ a b -> [simprec| ELet ext *a *b |] + EPair _ a b -> [simprec| EPair ext *a *b |] + EFst _ e -> [simprec| EFst ext *e |] + ESnd _ e -> [simprec| ESnd ext *e |] + ENil _ -> pure $ ENil ext + EInl _ t e -> [simprec| EInl ext t *e |] + EInr _ t e -> [simprec| EInr ext t *e |] + ECase _ e a b -> [simprec| ECase ext *e *a *b |] + ENothing _ t -> pure $ ENothing ext t + EJust _ e -> [simprec| EJust ext *e |] + EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |] + ELNil _ t1 t2 -> pure $ ELNil ext t1 t2 + ELInl _ t e -> [simprec| ELInl ext t *e |] + ELInr _ t e -> [simprec| ELInr ext t *e |] + ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |] + EConstArr _ n t v -> pure $ EConstArr ext n t v + EBuild _ n a b -> [simprec| EBuild ext n *a *b |] + EMap _ a b -> [simprec| EMap ext *a *b |] + EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |] + ESum1Inner _ e -> [simprec| ESum1Inner ext *e |] + EUnit _ e -> [simprec| EUnit ext *e |] + EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] + EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] + EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] + EReshape _ n a b -> [simprec| EReshape ext n *a *b |] + EZip _ a b -> [simprec| EZip ext *a *b |] + EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] + EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |] + EConst _ t v -> pure $ EConst ext t v + EIdx0 _ e -> [simprec| EIdx0 ext *e |] + EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] + EIdx _ a b -> [simprec| EIdx ext *a *b |] + EShape _ e -> [simprec| EShape ext *e |] + EOp _ op e -> [simprec| EOp ext op *e |] + ECustom _ s t p a b c e1 e2 -> do + a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a) + b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b) + c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c) + e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1) + e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2) + pure (ECustom ext s t p a' b' c' e1' e2') + ERecompute _ e -> [simprec| ERecompute ext *e |] + EWith _ t e1 e2 -> do + e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) + e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) + pure (EWith ext t e1' e2') + -- EOneHot _ t p e1 e2 -> [simprec| EOneHot ext t p *e1 *e2 |] + -- EAccum _ t p e1 sp e2 acc -> [simprec| EAccum ext t p *e1 sp *e2 *acc |] + EZero _ t e -> [simprec| EZero ext t *e |] + EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] + EPlus _ t a b -> [simprec| EPlus ext t *a *b |] + EError _ t s -> pure $ EError ext t s + +-- | This can be made more precise by tracking (and not counting) adds on +-- locally eliminated accumulators. +hasAdds :: Expr x env t -> Bool +hasAdds = \case + EVar _ _ _ -> False + ELet _ rhs body -> hasAdds rhs || hasAdds body + EPair _ a b -> hasAdds a || hasAdds b + EFst _ e -> hasAdds e + ESnd _ e -> hasAdds e + ENil _ -> False + EInl _ _ e -> hasAdds e + EInr _ _ e -> hasAdds e + ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b + ENothing _ _ -> False + EJust _ e -> hasAdds e + EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e + ELNil _ _ _ -> False + ELInl _ _ e -> hasAdds e + ELInr _ _ e -> hasAdds e + ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c + EConstArr _ _ _ _ -> False + EBuild _ _ a b -> hasAdds a || hasAdds b + EMap _ a b -> hasAdds a || hasAdds b + EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c + ESum1Inner _ e -> hasAdds e + EUnit _ e -> hasAdds e + EReplicate1Inner _ a b -> hasAdds a || hasAdds b + EMaximum1Inner _ e -> hasAdds e + EMinimum1Inner _ e -> hasAdds e + EReshape _ _ a b -> hasAdds a || hasAdds b + EZip _ a b -> hasAdds a || hasAdds b + EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c + EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c + ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e + EConst _ _ _ -> False + EIdx0 _ e -> hasAdds e + EIdx1 _ a b -> hasAdds a || hasAdds b + EIdx _ a b -> hasAdds a || hasAdds b + EShape _ e -> hasAdds e + EOp _ _ e -> hasAdds e + EWith _ _ a b -> hasAdds a || hasAdds b + ERecompute _ e -> hasAdds e + EAccum _ _ _ _ _ _ _ -> True + EZero _ _ e -> hasAdds e + EDeepZero _ _ e -> hasAdds e + EPlus _ _ a b -> hasAdds a || hasAdds b + EOneHot _ _ _ a b -> hasAdds a || hasAdds b + EError _ _ _ -> False + +checkAccumInScope :: SList STy env -> Bool +checkAccumInScope = \case SNil -> False + SCons t env -> check t || checkAccumInScope env + where + check :: STy t -> Bool + check STNil = False + check (STPair s t) = check s || check t + check (STEither s t) = check s || check t + check (STLEither s t) = check s || check t + check (STMaybe t) = check t + check (STArr _ t) = check t + check (STScal _) = False + check STAccum{} = True + +data OneHotTerm dense env a where + OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a +deriving instance Show (OneHotTerm dense env a) + +data InContext f env (a :: Ty) where + InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a + +simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do + val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val + return $ OneHotTerm dense t prj idx sp val' + +simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a) +simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) = + unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 -> + acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' -> + return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2) +simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht + +simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val)) + | Just Refl <- isDense (acPrjTy prj1 t1) sp = + let idx2' :: Ex env (AcIdx dense p2 c) + idx2' = case dense of + SAID -> reduceAcIdx t2 prj2 idx2 + SAIS -> idx2 + in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' -> + acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val +simplifyOHT_concat oht = return oht + +-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is +-- -- dense, then the Sparse in the output will also be dense. This property is +-- -- used when simplifying EOneHot, which cannot represent sparsity. +simplifyOHT :: ActedMonad m => OneHotTerm dense env a + -> m r -- ^ Zero case (onehot is actually zero) + -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot) + -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified + -> m r +simplifyOHT oht kzero ktriv k = do + -- traceM $ "sOHT: input " ++ show oht + oht1 <- simplifyOHT_recogniseMonoid oht + -- traceM $ "sOHT: recog " ++ show oht1 + InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1 + -- traceM $ "sOHT: unspa " ++ show oht2 + oht3 <- simplifyOHT_concat oht2 + -- traceM $ "sOHT: conca " ++ show oht3 + -- traceM "" + case oht3 of + OneHotTerm _ _ _ _ _ EZero{} -> kzero + OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val) + _ -> k (InContext w1 wrap1 oht3) + +-- Sets the acted flag whenever a non-trivial projection is returned or the +-- output Sparse is different from the input Sparse. +unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a' + -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s) + -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r +unsparseOneHotD topsp topval k = case (topsp, topval) of + -- eliminate always-Just sparse onehot + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k + + -- expand the top levels of a onehot for a sparse type into a onehot for the + -- corresponding non-sparse type + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPFst spprj) idx' s1' e' + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPSnd spprj) idx' s1' e' + (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPLeft spprj) idx' s1' e' + (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPRight spprj) idx' s1' e' + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPJust spprj) idx' s1' e' + (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val) + | Dict <- styKnown (typeOf idx) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' -> + acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e' + + -- anything else we don't know how to improve + _ -> k WId id SAPHere (ENil ext) topsp topval + +{- +unsparseOneHotS :: ActedMonad m + => Sparse a a' -> Ex env a' + -> (forall b. Sparse a b -> Ex env b -> m r) -> m r +unsparseOneHotS topsp topval k = case (topsp, topval) of + -- order is relevant to make sure we set the acted flag correctly + (SpAbsent, v@ENil{}) -> k SpAbsent v + (SpAbsent, v@EZero{}) -> k SpAbsent v + (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + + -- the unsparsifying + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k + + -- recursion + -- TODO: coproducts could safely become projections as they do not need + -- zeroinfo. But that would only work if the coproduct is at the top, because + -- as soon as we hit a product, we need zeroinfo to make it a projection and + -- we don't have that. + (SpSparse s, e) -> k (SpSparse s) e + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' -> + acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext)) + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' -> + acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e') + (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do + case s2 of SpAbsent -> pure () ; _ -> tellActed + k (SpLEither s1' SpAbsent) (ELInl ext STNil e') + (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do + case s1 of SpAbsent -> pure () ; _ -> tellActed + acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e') + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> + k (SpMaybe s1') (EJust ext e') + (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' -> + k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e') + _ -> _ +-} + +-- | Recognises 'EZero' and 'EOneHot'. +recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) +recogniseMonoid _ e@EOneHot{} = return e +recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext) +recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) = + ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case + (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2) + (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' + (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' + (a', b') -> return $ EPair ext a' b' +recogniseMonoid typ@(SMTLEither t1 t2) expr = + case expr of + ELNil{} -> acted $ return $ EZero ext typ (ENil ext) + ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e + _ -> return expr +recogniseMonoid typ@(SMTMaybe t1) expr = + case expr of + ENothing{} -> acted $ return $ EZero ext typ (ENil ext) + EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e + _ -> return expr +recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = + acted $ do + e' <- recogniseMonoid t e + return $ + ELet ext e' $ + EOneHot ext typ (SAPArrIdx SAPHere) + (EPair ext (EPair ext (ENil ext) (EUnit ext (makeZeroInfo t (EVar ext (fromSMTy t) IZ)))) + (ENil ext)) + (EVar ext (fromSMTy t) IZ) +recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of + (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext) + _ -> return e +recogniseMonoid _ e = return e + +reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a) +reduceAcIdx topty topprj e = case (topty, topprj) of + (_, SAPHere) -> ENil ext + (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) + (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) + (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e + (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e + (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e + (SMTArr _ t, SAPArrIdx p) -> + eunPair e $ \_ e1 e2 -> + EPair ext (efst e1) (reduceAcIdx t p e2) + +zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) +zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) + where + -- invariant: AcIdx expression is duplicable + go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) + go t SAPHere _ e = makeZeroInfo t e + go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) + go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) + go SMTLEither{} _ _ _ = ENil ext + go SMTMaybe{} _ _ _ = ENil ext + go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) diff --git a/src/CHAD/Simplify/TH.hs b/src/CHAD/Simplify/TH.hs new file mode 100644 index 0000000..4af5394 --- /dev/null +++ b/src/CHAD/Simplify/TH.hs @@ -0,0 +1,80 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module CHAD.Simplify.TH (simprec) where + +import Data.Bifunctor (first) +import Data.Char +import Data.List (foldl', foldl1') +import Language.Haskell.TH +import Language.Haskell.TH.Quote +import Text.ParserCombinators.ReadP + + +-- [simprec| EPair ext *a *b |] +-- ~> +-- do a' <- within (\a' -> EPair ext a' b) (simplify' a) +-- b' <- within (\b' -> EPair ext a' b') (simplify' b) +-- pure (EPair ext a' b') + +simprec :: QuasiQuoter +simprec = QuasiQuoter + { quoteDec = \_ -> fail "simprec used outside of expression context" + , quoteType = \_ -> fail "simprec used outside of expression context" + , quoteExp = handler + , quotePat = \_ -> fail "simprec used outside of expression context" + } + +handler :: String -> Q Exp +handler str = + case readP_to_S pTemplate str of + [(template, "")] -> generate template + _:_:_ -> fail "simprec: template grammar ambiguous" + _ -> fail "simprec: could not parse template" + +generate :: Template -> Q Exp +generate (Template topitems) = + let takePrefix (Plain x : xs) = first (x:) (takePrefix xs) + takePrefix xs = ([], xs) + + itemVar "" = error "simprec: empty item name?" + itemVar name@(c:_) | isLower c = VarE (mkName name) + | isUpper c = ConE (mkName name) + | otherwise = error "simprec: non-letter item name?" + + loop :: Exp -> [Item] -> Q [Stmt] + loop yet [] = return [NoBindS (VarE 'pure `AppE` yet)] + loop yet (Plain x : xs) = loop (yet `AppE` itemVar x) xs + loop yet (Recurse x : xs) = do + primeName <- newName (x ++ "'") + let appPrePrime e (Plain y) = e `AppE` itemVar y + appPrePrime e (Recurse y) = e `AppE` itemVar y + let stmt = BindS (VarP primeName) $ + VarE (mkName "within") + `AppE` LamE [VarP primeName] (foldl' appPrePrime (yet `AppE` VarE primeName) xs) + `AppE` (VarE (mkName "simplify'") `AppE` VarE (mkName x)) + stmts <- loop (yet `AppE` VarE primeName) xs + return (stmt : stmts) + + (prefix, items') = takePrefix topitems + in DoE Nothing <$> loop (foldl1' AppE (map itemVar prefix)) items' + +data Template = Template [Item] + deriving (Show) + +data Item = Plain String | Recurse String + deriving (Show) + +pTemplate :: ReadP Template +pTemplate = do + items <- many (skipSpaces >> pItem) + skipSpaces + eof + return (Template items) + +pItem :: ReadP Item +pItem = (char '*' >> Recurse <$> pName) +++ (Plain <$> pName) + +pName :: ReadP String +pName = do + c1 <- satisfy (\c -> isAlpha c || c == '_') + cs <- munch (\c -> isAlphaNum c || c `elem` "_'") + return (c1:cs) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs deleted file mode 100644 index 7f49cef..0000000 --- a/src/CHAD/Types.hs +++ /dev/null @@ -1,108 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module CHAD.Types where - -import AST.Types -import Data - - -type family D1 t where - D1 TNil = TNil - D1 (TPair a b) = TPair (D1 a) (D1 b) - D1 (TEither a b) = TEither (D1 a) (D1 b) - D1 (TMaybe a) = TMaybe (D1 a) - D1 (TArr n t) = TArr n (D1 t) - D1 (TScal t) = TScal t - -type family D2 t where - D2 TNil = TNil - D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) - D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b)) - D2 (TMaybe t) = TMaybe (D2 t) - D2 (TArr n t) = TMaybe (TArr n (D2 t)) - D2 (TScal t) = D2s t - -type family D2s t where - D2s TI32 = TNil - D2s TI64 = TNil - D2s TF32 = TScal TF32 - D2s TF64 = TScal TF64 - D2s TBool = TNil - -type family D1E env where - D1E '[] = '[] - D1E (t : env) = D1 t : D1E env - -type family D2E env where - D2E '[] = '[] - D2E (t : env) = D2 t : D2E env - -type family D2AcE env where - D2AcE '[] = '[] - D2AcE (t : env) = TAccum t : D2AcE env - -d1 :: STy t -> STy (D1 t) -d1 STNil = STNil -d1 (STPair a b) = STPair (d1 a) (d1 b) -d1 (STEither a b) = STEither (d1 a) (d1 b) -d1 (STMaybe t) = STMaybe (d1 t) -d1 (STArr n t) = STArr n (d1 t) -d1 (STScal t) = STScal t -d1 STAccum{} = error "Accumulators not allowed in input program" - -d1e :: SList STy env -> SList STy (D1E env) -d1e SNil = SNil -d1e (t `SCons` env) = d1 t `SCons` d1e env - -d2 :: STy t -> STy (D2 t) -d2 STNil = STNil -d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b)) -d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b)) -d2 (STMaybe t) = STMaybe (d2 t) -d2 (STArr n t) = STMaybe (STArr n (d2 t)) -d2 (STScal t) = case t of - STI32 -> STNil - STI64 -> STNil - STF32 -> STScal STF32 - STF64 -> STScal STF64 - STBool -> STNil -d2 STAccum{} = error "Accumulators not allowed in input program" - -d2e :: SList STy env -> SList STy (D2E env) -d2e SNil = SNil -d2e (t `SCons` ts) = d2 t `SCons` d2e ts - -d2ace :: SList STy env -> SList STy (D2AcE env) -d2ace SNil = SNil -d2ace (t `SCons` ts) = STAccum t `SCons` d2ace ts - - -data CHADConfig = CHADConfig - { -- | D[let] will bind variables containing arrays in accumulator mode. - chcLetArrayAccum :: Bool - , -- | D[case] will bind variables containing arrays in accumulator mode. - chcCaseArrayAccum :: Bool - , -- | Introduce top-level arguments containing arrays in accumulator mode. - chcArgArrayAccum :: Bool - } - deriving (Show) - -defaultConfig :: CHADConfig -defaultConfig = CHADConfig - { chcLetArrayAccum = False - , chcCaseArrayAccum = False - , chcArgArrayAccum = False - } - -chcSetAccum :: CHADConfig -> CHADConfig -chcSetAccum c = c { chcLetArrayAccum = True - , chcCaseArrayAccum = True - , chcArgArrayAccum = True } - - ------------------------------------- LEMMAS ------------------------------------ - -indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) -indexTupD1Id SZ = Refl -indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl diff --git a/src/Util/IdGen.hs b/src/CHAD/Util/IdGen.hs index 3f6611d..d4fd945 100644 --- a/src/Util/IdGen.hs +++ b/src/CHAD/Util/IdGen.hs @@ -1,6 +1,6 @@ {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -module Util.IdGen where +module CHAD.Util.IdGen where import Control.Monad.Fix import Control.Monad.Trans.State.Strict diff --git a/src/Interpreter.hs b/src/Interpreter.hs deleted file mode 100644 index 58d79a5..0000000 --- a/src/Interpreter.hs +++ /dev/null @@ -1,448 +0,0 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Interpreter ( - interpret, - interpretOpen, - Value(..), -) where - -import Control.Monad (foldM, join, when, forM_) -import Data.Bitraversable (bitraverse) -import Data.Char (isSpace) -import Data.Functor.Identity -import qualified Data.Functor.Product as Product -import Data.Int (Int64) -import Data.IORef -import System.IO (hPutStrLn, stderr) -import System.IO.Unsafe (unsafePerformIO) - -import Debug.Trace - -import Array -import AST -import AST.Pretty -import CHAD.Types -import Data -import Interpreter.Rep - - -newtype AcM s a = AcM { unAcM :: IO a } - deriving newtype (Functor, Applicative, Monad) - -runAcM :: (forall s. AcM s a) -> a -runAcM (AcM m) = unsafePerformIO m - -acmDebugLog :: String -> AcM s () -acmDebugLog s = AcM (hPutStrLn stderr s) - -data V t = V (STy t) (Rep t) - -interpret :: Ex '[] t -> Rep t -interpret = interpretOpen False SNil SNil - --- | Bool: whether to trace execution with debug prints (very verbose) -interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t -interpretOpen prints env venv e = - runAcM $ - let ?depth = 0 - ?prints = prints - in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e - -interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) - => SList V env -> Ex env t -> AcM s (Rep t) -interpret' env e = do - let tenv = slistMap (\(V t _) -> t) env - let dep = ?depth - let lenlimit = max 20 (100 - dep) - let replace a b = map (\c -> if c == a then b else c) - let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..." - | otherwise = replace '\n' ' ' s - when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e) - res <- let ?depth = dep + 1 in interpret'Rec env e - when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" - return res - -interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t) -interpret'Rec env = \case - EVar _ _ i -> case slistIdx env i of V _ x -> return x - ELet _ a b -> do - x <- interpret' env a - let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b - expr | False && trace ("<i> " ++ takeWhile (not . isSpace) (show expr)) False -> undefined - EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b - EFst _ e -> fst <$> interpret' env e - ESnd _ e -> snd <$> interpret' env e - ENil _ -> return () - EInl _ _ e -> Left <$> interpret' env e - EInr _ _ e -> Right <$> interpret' env e - ECase _ e a b -> - let STEither t1 t2 = typeOf e - in interpret' env e >>= \case - Left x -> interpret' (V t1 x `SCons` env) a - Right y -> interpret' (V t2 y `SCons` env) b - ENothing _ _ -> return Nothing - EJust _ e -> Just <$> interpret' env e - EMaybe _ a b e -> - let STMaybe t1 = typeOf e - in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e - EConstArr _ _ _ v -> return v - EBuild _ dim a b -> do - sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a - arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) - EFold1Inner _ _ a b c -> do - let t = typeOf b - let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a - x0 <- interpret' env b - arr <- interpret' env c - let sh `ShCons` n = arrayShape arr - arrayGenerateM sh $ \idx -> foldM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] - ESum1Inner _ e -> do - arr <- interpret' env e - let STArr _ (STScal t) = typeOf e - sh `ShCons` n = arrayShape arr - numericIsNum t $ return $ arrayGenerate sh $ \idx -> sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] - EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e) - EReplicate1Inner _ a b -> do - n <- fromIntegral @Int64 @Int <$> interpret' env a - arr <- interpret' env b - let sh = arrayShape arr - return $ arrayGenerate (sh `ShCons` n) (\(idx `IxCons` _) -> arrayIndex arr idx) - EMaximum1Inner _ e -> do - arr <- interpret' env e - let STArr _ (STScal t) = typeOf e - sh `ShCons` n = arrayShape arr - numericIsNum t $ return $ - arrayGenerate sh (\idx -> maximum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) - EMinimum1Inner _ e -> do - arr <- interpret' env e - let STArr _ (STScal t) = typeOf e - sh `ShCons` n = arrayShape arr - numericIsNum t $ return $ - arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) - EConst _ _ v -> return v - EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e - EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) - EIdx _ a b - | STArr n _ <- typeOf a - -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) - EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e - EOp _ op e -> interpretOp op <$> interpret' env e - ECustom _ t1 t2 _ pr _ _ e1 e2 -> do - e1' <- interpret' env e1 - e2' <- interpret' env e2 - interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr - EWith _ t e1 e2 -> do - initval <- interpret' env e1 - withAccum t (typeOf e2) initval $ \accum -> - interpret' (V (STAccum t) accum `SCons` env) e2 - EAccum _ t p e1 e2 e3 -> do - idx <- interpret' env e1 - val <- interpret' env e2 - accum <- interpret' env e3 - accumAddSparse t p accum idx val - EZero _ t -> do - return $ zeroD2 t - EPlus _ t a b -> do - a' <- interpret' env a - b' <- interpret' env b - return $ addD2s t a' b' - EOneHot _ t p a b -> do - a' <- interpret' env a - b' <- interpret' env b - return $ onehotD2 p t a' b' - EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s - -interpretOp :: SOp a t -> Rep a -> Rep t -interpretOp op arg = case op of - OAdd st -> numericIsNum st $ uncurry (+) arg - OMul st -> numericIsNum st $ uncurry (*) arg - ONeg st -> numericIsNum st $ negate arg - OLt st -> numericIsNum st $ uncurry (<) arg - OLe st -> numericIsNum st $ uncurry (<=) arg - OEq st -> styIsEq st $ uncurry (==) arg - ONot -> not arg - OAnd -> uncurry (&&) arg - OOr -> uncurry (||) arg - OIf -> if arg then Left () else Right () - ORound64 -> round arg - OToFl64 -> fromIntegral arg - ORecip st -> floatingIsFractional st $ recip arg - OExp st -> floatingIsFractional st $ exp arg - OLog st -> floatingIsFractional st $ log arg - OIDiv st -> integralIsIntegral st $ uncurry quot arg - OMod st -> integralIsIntegral st $ uncurry rem arg - where - styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r - styIsEq STI32 = id - styIsEq STI64 = id - styIsEq STF32 = id - styIsEq STF64 = id - styIsEq STBool = id - -zeroD2 :: STy t -> Rep (D2 t) -zeroD2 typ = case typ of - STNil -> () - STPair _ _ -> Nothing - STEither _ _ -> Nothing - STMaybe _ -> Nothing - STArr _ _ -> Nothing - STScal sty -> case sty of - STI32 -> () - STI64 -> () - STF32 -> 0.0 - STF64 -> 0.0 - STBool -> () - STAccum{} -> error "Zero of Accum" - -addD2s :: STy t -> Rep (D2 t) -> Rep (D2 t) -> Rep (D2 t) -addD2s typ a b = case typ of - STNil -> () - STPair t1 t2 -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just (x1, x2), Just (y1, y2)) -> Just (addD2s t1 x1 y1, addD2s t2 x2 y2) - STEither t1 t2 -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just (Left x), Just (Left y)) -> Just (Left (addD2s t1 x y)) - (Just (Right x), Just (Right y)) -> Just (Right (addD2s t2 x y)) - _ -> error "Plus of inconsistent Eithers" - STMaybe t -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just x, Just y) -> Just (addD2s t x y) - STArr _ t -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just x, Just y) -> - let sh1 = arrayShape x - sh2 = arrayShape y - in if | shapeSize sh1 == 0 -> Just y - | shapeSize sh2 == 0 -> Just x - | sh1 == sh2 -> Just $ arrayGenerateLin sh1 (\i -> addD2s t (arrayIndexLinear x i) (arrayIndexLinear y i)) - | otherwise -> error "Plus of inconsistently shaped arrays" - STScal sty -> case sty of - STI32 -> () - STI64 -> () - STF32 -> a + b - STF64 -> a + b - STBool -> () - STAccum{} -> error "Plus of Accum" - -onehotD2 :: SAcPrj p a b -> STy a -> Rep (AcIdx p a) -> Rep (D2 b) -> Rep (D2 a) -onehotD2 SAPHere _ _ val = val -onehotD2 (SAPFst prj) (STPair a b) idx val = Just (onehotD2 prj a idx val, zeroD2 b) -onehotD2 (SAPSnd prj) (STPair a b) idx val = Just (zeroD2 a, onehotD2 prj b idx val) -onehotD2 (SAPLeft prj) (STEither a _) idx val = Just (Left (onehotD2 prj a idx val)) -onehotD2 (SAPRight prj) (STEither _ b) idx val = Just (Right (onehotD2 prj b idx val)) -onehotD2 (SAPJust prj) (STMaybe a) idx val = Just (onehotD2 prj a idx val) -onehotD2 (SAPArrIdx prj _) (STArr n a) idx val = - Just $ runIdentity $ onehotArray (\idx' -> Identity (onehotD2 prj a idx' val)) (Identity (zeroD2 a)) n prj idx - -withAccum :: STy t -> STy a -> Rep (D2 t) -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep (D2 t)) -withAccum t _ initval f = AcM $ do - accum <- newAcSparse t SAPHere () initval - out <- unAcM $ f accum - val <- readAcSparse t accum - return (out, val) - -newAcZero :: STy t -> IO (RepAc t) -newAcZero = \case - STNil -> return () - STPair{} -> newIORef Nothing - STEither{} -> newIORef Nothing - STMaybe _ -> newIORef Nothing - STArr _ _ -> newIORef Nothing - STScal sty -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> newIORef 0.0 - STF64 -> newIORef 0.0 - STBool -> return () - STAccum{} -> error "Nested accumulators" - -newAcSparse :: STy a -> SAcPrj p a b -> Rep (AcIdx p a) -> Rep (D2 b) -> IO (RepAc a) -newAcSparse typ prj idx val = case (typ, prj) of - (STNil, SAPHere) -> return () - (STPair t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val - (STEither t1 t2, SAPHere) -> newIORef =<< traverse (bitraverse (newAcSparse t1 SAPHere ()) (newAcSparse t2 SAPHere ())) val - (STMaybe t1, SAPHere) -> newIORef =<< traverse (newAcSparse t1 SAPHere ()) val - (STArr _ t1, SAPHere) -> newIORef =<< traverse (traverse (newAcSparse t1 SAPHere ())) val - (STScal sty, SAPHere) -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> newIORef val - STF64 -> newIORef val - STBool -> return () - - (STPair t1 t2, SAPFst prj') -> - newIORef . Just =<< (,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2 - (STPair t1 t2, SAPSnd prj') -> - newIORef . Just =<< (,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val - - (STEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val - (STEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val - - (STMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val - - (STArr n t, SAPArrIdx prj' _) -> newIORef . Just =<< newAcArray n t prj' idx val - - (STAccum{}, _) -> error "Accumulators not allowed in source program" - -newAcArray :: SNat n -> STy a -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> Rep (D2 b) -> IO (Array n (RepAc a)) -newAcArray n t prj idx val = onehotArray (\idx' -> newAcSparse t prj idx' val) (newAcZero t) n prj idx - -onehotArray :: Monad m - => (Rep (AcIdx p a) -> m v) -- ^ the "one" - -> m v -- ^ the "zero" - -> SNat n -> SAcPrj p a b -> Rep (AcIdx (APArrIdx p) (TArr n a)) -> m (Array n v) -onehotArray mkone mkzero n _ ((arrindex', arrsh'), idx) = - let arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = unTupRepIdx ShNil ShCons n arrsh' - !linindex = toLinearIndex arrsh arrindex - in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero) - -readAcSparse :: STy t -> RepAc t -> IO (Rep (D2 t)) -readAcSparse typ val = case typ of - STNil -> return () - STPair t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - STEither t1 t2 -> traverse (bitraverse (readAcSparse t1) (readAcSparse t2)) =<< readIORef val - STMaybe t -> traverse (readAcSparse t) =<< readIORef val - STArr _ t -> traverse (traverse (readAcSparse t)) =<< readIORef val - STScal sty -> case sty of - STI32 -> return () - STI64 -> return () - STF32 -> readIORef val - STF64 -> readIORef val - STBool -> return () - STAccum{} -> error "Nested accumulators" - -accumAddSparse :: STy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdx p a) -> Rep (D2 b) -> AcM s () -accumAddSparse typ prj ref idx val = case (typ, prj) of - (STNil, SAPHere) -> return () - - (STPair t1 t2, SAPHere) -> - case val of - Nothing -> return () - Just (val1, val2) -> - realiseMaybeSparse ref ((,) <$> newAcSparse t1 SAPHere () val1 - <*> newAcSparse t2 SAPHere () val2) - (\(ac1, ac2) -> do accumAddSparse t1 SAPHere ac1 () val1 - accumAddSparse t2 SAPHere ac2 () val2) - (STPair t1 t2, SAPFst prj') -> - realiseMaybeSparse ref ((,) <$> newAcSparse t1 prj' idx val <*> newAcZero t2) - (\(ac1, _) -> do accumAddSparse t1 prj' ac1 idx val) - (STPair t1 t2, SAPSnd prj') -> - realiseMaybeSparse ref ((,) <$> newAcZero t1 <*> newAcSparse t2 prj' idx val) - (\(_, ac2) -> do accumAddSparse t2 prj' ac2 idx val) - - (STEither{}, SAPHere) -> - case val of - Nothing -> return () - Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1 - Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2 - (STEither t1 _, SAPLeft prj') -> - realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val) - (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val - Right{} -> error "Mismatched Either in accumAddSparse (r +l)") - (STEither _ t2, SAPRight prj') -> - realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val) - (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val - Left{} -> error "Mismatched Either in accumAddSparse (l +r)") - - (STMaybe{}, SAPHere) -> - case val of - Nothing -> return () - Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val' - (STMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (newAcSparse t1 prj' idx val) - (\ac -> accumAddSparse t1 prj' ac idx val) - - (STArr _ t1, SAPHere) -> - case val of - Nothing -> return () - Just val' -> - realiseMaybeSparse ref - (arrayMapM (newAcSparse t1 SAPHere ()) val') - (\ac -> forM_ [0 .. arraySize ac - 1] $ \i -> - accumAddSparse t1 SAPHere (arrayIndexLinear ac i) () (arrayIndexLinear val' i)) - (STArr n t1, SAPArrIdx prj' _) -> - let ((arrindex', arrsh'), idx') = idx - arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = unTupRepIdx ShNil ShCons n arrsh' - linindex = toLinearIndex arrsh arrindex - in realiseMaybeSparse ref - (onehotArray (\_ -> newAcSparse t1 prj' idx' val) (newAcZero t1) n prj' idx) - (\ac -> accumAddSparse t1 prj' (arrayIndexLinear ac linindex) idx' val) - - (STScal sty, SAPHere) -> AcM $ case sty of - STI32 -> return () - STI64 -> return () - STF32 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STF64 -> atomicModifyIORef' ref (\x -> (x + val, ())) - STBool -> return () - - (STAccum{}, _) -> error "Accumulators not allowed in source program" - -realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () -realiseMaybeSparse ref makeval modifyval = - -- Try modifying what's already in ref. The 'join' makes the snd - -- of the function's return value a _continuation_ that is run after - -- the critical section ends. - AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of - -- Oops, ref's contents was still sparse. Have to initialise - -- it first, then try again. - Nothing -> (ac, do val <- makeval - join $ atomicModifyIORef' ref $ \ac' -> case ac' of - Nothing -> (Just val, return ()) - Just val' -> (ac', unAcM $ modifyval val')) - -- Yep, ref already had a value in there, so we can just add - -- val' to it recursively. - Just val -> (ac, unAcM $ modifyval val) - - -numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r -numericIsNum STI32 = id -numericIsNum STI64 = id -numericIsNum STF32 = id -numericIsNum STF64 = id - -floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsFloating st ~ True) => r) -> r -floatingIsFractional STF32 = id -floatingIsFractional STF64 = id - -integralIsIntegral :: ScalIsIntegral st ~ True => SScalTy st -> ((Integral (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsIntegral st ~ True) => r) -> r -integralIsIntegral STI32 = id -integralIsIntegral STI64 = id - -unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) - -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n -unTupRepIdx nil _ SZ _ = nil -unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i - -tupRepIdx :: (forall m. f (S m) -> (f m, Int)) - -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) -tupRepIdx _ SZ _ = () -tupRepIdx uncons (SS n) tup = - let (tup', i) = uncons tup - in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i - -ixUncons :: Index (S n) -> (Index n, Int) -ixUncons (IxCons idx i) = (idx, i) - -shUncons :: Shape (S n) -> (Shape n, Int) -shUncons (ShCons idx i) = (idx, i) diff --git a/src/Language.hs b/src/Language.hs deleted file mode 100644 index 4ed4eaa..0000000 --- a/src/Language.hs +++ /dev/null @@ -1,226 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ExplicitForAll #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeApplications #-} -module Language ( - fromNamed, - NExpr, - Ex, - module Language, - module AST.Types, - module Data, - Lookup, -) where - -import Array -import AST -import AST.Types -import CHAD.Types -import Data -import Language.AST - - -data a :-> b = a :-> b - deriving (Show) -infixr 0 :-> - - -body :: NExpr env t -> NFun env env t -body = NBody - -lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t -lambda = NLam - -inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t -inline = inlineNFun - --- To be used to construct the argument list for 'inline'. --- --- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y --- > in inline fun (SNil .$ 16 .$ 26) -(.$) :: SList f list -> f a -> SList f (a : list) -(.$) = flip SCons - - -let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t -let_ = NELet - -pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) -pair = NEPair - -fst_ :: NExpr env (TPair a b) -> NExpr env a -fst_ = NEFst - -snd_ :: NExpr env (TPair a b) -> NExpr env b -snd_ = NESnd - -nil :: NExpr env TNil -nil = NENil - -inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b) -inl = NEInl knownTy - -inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b) -inr = NEInr knownTy - -case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c -case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 - -nothing :: KnownTy a => NExpr env (TMaybe a) -nothing = NENothing knownTy - -just :: NExpr env a -> NExpr env (TMaybe a) -just = NEJust - -maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b -maybe_ a (v :-> b) c = NEMaybe a v b c - -constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) -constArr_ x = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConstArr knownNat ty x - -build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) -build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b)) - -build2 :: NExpr env TIx -> NExpr env TIx - -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t) - -> NExpr env (TArr (S (S Z)) t) -build2 a1 a2 (v1 :-> v2 :-> b) = - NEBuild (SS (SS SZ)) - (pair (pair nil a1) a2) - #idx - (let_ v1 (snd_ (fst_ #idx)) $ - let_ v2 (NEDrop SZ (snd_ #idx)) $ - NEDrop (SS (SS SZ)) b) - -build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) -build n a (v :-> b) = NEBuild n a v b - -map_ :: forall n a b env name. (KnownNat n, KnownTy a) - => (Var name a :-> NExpr ('(name, a) : env) b) - -> NExpr env (TArr n a) -> NExpr env (TArr n b) -map_ (v :-> a) b - | Dict <- styKnown (tTup (sreplicate (knownNat @n) tIx)) = - let_ #arg b $ - build knownNat (shape #arg) $ #i :-> - let_ v (#arg ! #i) $ - NEDrop (SS SZ) (NEDrop (SS SZ) a) - -fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) -fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3 - -sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -sum1i e = NESum1Inner e - -unit :: NExpr env t -> NExpr env (TArr Z t) -unit = NEUnit - -replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t)) -replicate1i n a = NEReplicate1Inner n a - -maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -maximum1i e = NEMaximum1Inner e - -minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -minimum1i e = NEMinimum1Inner e - -const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) -const_ x = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConst ty x - -idx0 :: NExpr env (TArr Z t) -> NExpr env t -idx0 = NEIdx0 - --- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) --- (.!) = NEIdx1 --- infixl 9 .! - -(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t -(!) = NEIdx -infixl 9 ! - -shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -shape = NEShape - -length_ :: NExpr env (TArr N1 t) -> NExpr env TIx -length_ e = snd_ (shape e) - -oper :: SOp a t -> NExpr env a -> NExpr env t -oper = NEOp - -oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t -oper2 op a b = NEOp op (pair a b) - -error_ :: KnownTy t => String -> NExpr env t -error_ s = NEError knownTy s - -custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) - -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)) - -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)) - -> NExpr env a -> NExpr env b - -> NExpr env t -custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = - NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 - -with :: forall t a env acname. KnownTy t => NExpr env (D2 t) -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a (D2 t)) -with a (n :-> b) = NEWith (knownTy @t) a n b - -accum :: KnownTy t => SAcPrj p t a -> NExpr env (AcIdx p t) -> NExpr env (D2 a) -> NExpr env (TAccum t) -> NExpr env TNil -accum p a b c = NEAccum knownTy p a b c - - -(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .== b = oper (OEq knownScalTy) (pair a b) -infix 4 .== - -(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .< b = oper (OLt knownScalTy) (pair a b) -infix 4 .< - -(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -(.>) = flip (.<) -infix 4 .> - -(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .<= b = oper (OLe knownScalTy) (pair a b) -infix 4 .<= - -(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -(.>=) = flip (.<=) -infix 4 .>= - -not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -not_ = oper ONot - -and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) -and_ = oper2 OAnd -infixr 3 `and_` - -or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) -or_ = oper2 OOr -infixr 2 `or_` - -mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) -mod_ = oper2 (OMod knownScalTy) -infixl 7 `mod_` - --- | The first alternative is the True case; the second is the False case. -if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t -if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) - -round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64) -round_ = oper ORound64 - -toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64) -toFloat_ = oper OToFl64 - -idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t) -idiv = oper2 (OIDiv knownScalTy) -infixl 7 `idiv` diff --git a/src/Simplify.hs b/src/Simplify.hs deleted file mode 100644 index e0ab37b..0000000 --- a/src/Simplify.hs +++ /dev/null @@ -1,348 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Simplify ( - simplifyN, simplifyFix, - SimplifyConfig(..), simplifyWith, simplifyFixWith, -) where - -import Data.Function (fix) -import Data.Monoid (Any(..)) -import Data.Type.Equality (testEquality) - -import AST -import AST.Count -import CHAD.Types -import Data - - --- | This has no fields now, hence this type is useless as-is. When debugging, however, it's useful to be able to add some. -data SimplifyConfig = SimplifyConfig - -defaultSimplifyConfig :: SimplifyConfig -defaultSimplifyConfig = SimplifyConfig - -simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t -simplifyN 0 = id -simplifyN n = simplifyN (n - 1) . simplify - -simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t -simplify = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = defaultSimplifyConfig - in snd . simplify' - -simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t -simplifyWith config = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = config - in snd . simplify' - -simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t -simplifyFix = simplifyFixWith defaultSimplifyConfig - -simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t -simplifyFixWith config = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = config - in fix $ \loop e -> - let (Any act, e') = simplify' e - in if act then loop e' else e' - -simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig) => Ex env t -> (Any, Ex env t) -simplify' = \case - -- inlining - ELet _ rhs body - | cheapExpr rhs - -> acted $ simplify' (substInline rhs body) - - | Occ lexOcc runOcc <- occCount IZ body - , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply - || (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not - -> acted $ simplify' (substInline rhs body) - - -- let splitting / let peeling - ELet _ (EPair _ a b) body -> - acted $ simplify' $ - ELet ext a $ - ELet ext (weakenExpr WSink b) $ - subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ) - IS i -> EVar ext t (IS (IS i))) - body - ELet _ (EJust _ a) body -> - acted $ simplify' $ ELet ext a $ subst0 (EJust ext (EVar ext (typeOf a) IZ)) body - ELet _ (EInl _ t2 a) body -> - acted $ simplify' $ ELet ext a $ subst0 (EInl ext t2 (EVar ext (typeOf a) IZ)) body - ELet _ (EInr _ t1 a) body -> - acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body - - -- let rotation - ELet _ (ELet _ rhs a) b -> - acted $ simplify' $ - ELet ext rhs $ - ELet ext a $ - weakenExpr (WCopy WSink) (snd (simplify' b)) - - -- beta rules for products - EFst _ (EPair _ e e') - | not (hasAdds e') -> acted $ simplify' e - | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) - ESnd _ (EPair _ e' e) - | not (hasAdds e') -> acted $ simplify' e - | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) - - -- beta rules for coproducts - ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs) - ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs) - - -- beta rules for maybe - EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1 - EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1 - - -- let floating - EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body)) - ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body)) - ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) - EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) - EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) - EAccum _ t p e1 (ELet _ rhs body) acc -> - acted $ simplify' $ - ELet ext rhs $ - EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc) - - -- let () = e in () ~> e - ELet _ e1 (ENil _) | STNil <- typeOf e1 -> - acted $ simplify' e1 - - -- projection down-commuting - EFst _ (ECase _ e1 e2 e3) -> - acted $ simplify' $ - ECase ext e1 (EFst ext e2) (EFst ext e3) - ESnd _ (ECase _ e1 e2 e3) -> - acted $ simplify' $ - ECase ext e1 (ESnd ext e2) (ESnd ext e3) - - -- TODO: array indexing (index of build, index of fold) - - -- TODO: beta rules for maybe - - -- TODO: constant folding for operations - - -- monoid rules - EAccum _ t p e1 e2 acc -> do - e1' <- simplify' e1 - e2' <- simplify' e2 - acc' <- simplify' acc - simplifyOneHotTerm (OneHotTerm t p e1' e2') - (Any True, ENil ext) - (\e -> (Any False, EAccum ext t SAPHere (ENil ext) e acc')) - (\(OneHotTerm t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc')) - EPlus _ _ (EZero _ _) e -> acted $ simplify' e - EPlus _ _ e (EZero _ _) -> acted $ simplify' e - EOneHot _ t p e1 e2 -> do - e1' <- simplify' e1 - e2' <- simplify' e2 - simplifyOneHotTerm (OneHotTerm t p e1' e2') - (Any True, EZero ext t) - (\e -> (Any True, e)) - (\(OneHotTerm t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2'')) - - -- type-specific equations for plus - EPlus _ STNil _ _ -> (Any True, ENil ext) - - EPlus _ (STPair t1 t2) (EJust _ (EPair _ a1 b1)) (EJust _ (EPair _ a2 b2)) -> - acted $ simplify' $ EJust ext (EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2)) - EPlus _ STPair{} ENothing{} e -> acted $ simplify' e - EPlus _ STPair{} e ENothing{} -> acted $ simplify' e - - EPlus _ (STEither t1 _) (EJust _ (EInl _ dt2 a1)) (EJust _ (EInl _ _ a2)) -> - acted $ simplify' $ EJust ext (EInl ext dt2 (EPlus ext t1 a1 a2)) - EPlus _ (STEither _ t2) (EJust _ (EInr _ dt1 b1)) (EJust _ (EInr _ _ b2)) -> - acted $ simplify' $ EJust ext (EInr ext dt1 (EPlus ext t2 b1 b2)) - EPlus _ STEither{} ENothing{} e -> acted $ simplify' e - EPlus _ STEither{} e ENothing{} -> acted $ simplify' e - - EPlus _ (STMaybe t) (EJust _ e1) (EJust _ e2) -> - acted $ simplify' $ EJust ext (EPlus ext t e1 e2) - EPlus _ STMaybe{} ENothing{} e -> acted $ simplify' e - EPlus _ STMaybe{} e ENothing{} -> acted $ simplify' e - - -- fallback recursion - EVar _ t i -> pure $ EVar ext t i - ELet _ a b -> ELet ext <$> simplify' a <*> simplify' b - EPair _ a b -> EPair ext <$> simplify' a <*> simplify' b - EFst _ e -> EFst ext <$> simplify' e - ESnd _ e -> ESnd ext <$> simplify' e - ENil _ -> pure $ ENil ext - EInl _ t e -> EInl ext t <$> simplify' e - EInr _ t e -> EInr ext t <$> simplify' e - ECase _ e a b -> ECase ext <$> simplify' e <*> simplify' a <*> simplify' b - ENothing _ t -> pure $ ENothing ext t - EJust _ e -> EJust ext <$> simplify' e - EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e - EConstArr _ n t v -> pure $ EConstArr ext n t v - EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b - EFold1Inner _ cm a b c -> EFold1Inner ext cm <$> simplify' a <*> simplify' b <*> simplify' c - ESum1Inner _ e -> ESum1Inner ext <$> simplify' e - EUnit _ e -> EUnit ext <$> simplify' e - EReplicate1Inner _ a b -> EReplicate1Inner ext <$> simplify' a <*> simplify' b - EMaximum1Inner _ e -> EMaximum1Inner ext <$> simplify' e - EMinimum1Inner _ e -> EMinimum1Inner ext <$> simplify' e - EConst _ t v -> pure $ EConst ext t v - EIdx0 _ e -> EIdx0 ext <$> simplify' e - EIdx1 _ a b -> EIdx1 ext <$> simplify' a <*> simplify' b - EIdx _ a b -> EIdx ext <$> simplify' a <*> simplify' b - EShape _ e -> EShape ext <$> simplify' e - EOp _ op e -> EOp ext op <$> simplify' e - ECustom _ s t p a b c e1 e2 -> - ECustom ext s t p - <$> (let ?accumInScope = False in simplify' a) - <*> (let ?accumInScope = False in simplify' b) - <*> (let ?accumInScope = False in simplify' c) - <*> simplify' e1 <*> simplify' e2 - EWith _ t e1 e2 -> EWith ext t <$> simplify' e1 <*> (let ?accumInScope = True in simplify' e2) - EZero _ t -> pure $ EZero ext t - EPlus _ t a b -> EPlus ext t <$> simplify' a <*> simplify' b - EError _ t s -> pure $ EError ext t s - -acted :: (Any, a) -> (Any, a) -acted (_, x) = (Any True, x) - -cheapExpr :: Expr x env t -> Bool -cheapExpr = \case - EVar{} -> True - ENil{} -> True - EConst{} -> True - EFst _ e -> cheapExpr e - ESnd _ e -> cheapExpr e - _ -> False - --- | This can be made more precise by tracking (and not counting) adds on --- locally eliminated accumulators. -hasAdds :: Expr x env t -> Bool -hasAdds = \case - EVar _ _ _ -> False - ELet _ rhs body -> hasAdds rhs || hasAdds body - EPair _ a b -> hasAdds a || hasAdds b - EFst _ e -> hasAdds e - ESnd _ e -> hasAdds e - ENil _ -> False - EInl _ _ e -> hasAdds e - EInr _ _ e -> hasAdds e - ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b - ENothing _ _ -> False - EJust _ e -> hasAdds e - EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e - EConstArr _ _ _ _ -> False - EBuild _ _ a b -> hasAdds a || hasAdds b - EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c - ESum1Inner _ e -> hasAdds e - EUnit _ e -> hasAdds e - EReplicate1Inner _ a b -> hasAdds a || hasAdds b - EMaximum1Inner _ e -> hasAdds e - EMinimum1Inner _ e -> hasAdds e - ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e - EConst _ _ _ -> False - EIdx0 _ e -> hasAdds e - EIdx1 _ a b -> hasAdds a || hasAdds b - EIdx _ a b -> hasAdds a || hasAdds b - EShape _ e -> hasAdds e - EOp _ _ e -> hasAdds e - EWith _ _ a b -> hasAdds a || hasAdds b - EAccum _ _ _ _ _ _ -> True - EZero _ _ -> False - EPlus _ _ a b -> hasAdds a || hasAdds b - EOneHot _ _ _ a b -> hasAdds a || hasAdds b - EError _ _ _ -> False - -checkAccumInScope :: SList STy env -> Bool -checkAccumInScope = \case SNil -> False - SCons t env -> check t || checkAccumInScope env - where - check :: STy t -> Bool - check STNil = False - check (STPair s t) = check s || check t - check (STEither s t) = check s || check t - check (STMaybe t) = check t - check (STArr _ t) = check t - check (STScal _) = False - check STAccum{} = True - -data OneHotTerm env p a b where - OneHotTerm :: STy a -> SAcPrj p a b -> Ex env (AcIdx p a) -> Ex env (D2 b) -> OneHotTerm env p a b -deriving instance Show (OneHotTerm env p a b) - -simplifyOneHotTerm :: OneHotTerm env p a b - -> (Any, r) -- ^ Zero case (onehot is actually zero) - -> (Ex env (D2 a) -> (Any, r)) -- ^ Trivial case (no zeros in onehot) - -> (forall p' b'. OneHotTerm env p' a b' -> (Any, r)) - -> (Any, r) -simplifyOneHotTerm (OneHotTerm _ _ _ (EZero _ _)) kzero _ _ = kzero - -simplifyOneHotTerm (OneHotTerm t1 prj1 idx1 (EOneHot _ t2 prj2 idx2 val)) kzero ktriv k - | Just Refl <- testEquality (acPrjTy prj1 t1) t2 - = do (Any True, ()) -- record, whatever happens later, that we've modified something - concatOneHots t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 -> - simplifyOneHotTerm (OneHotTerm t1 prj12 idx12 val) kzero ktriv k - -simplifyOneHotTerm (OneHotTerm t SAPHere idx e) kzero ktriv k = case (t, e) of - (STNil, _) -> kzero - - (STPair{}, ENothing _ _) -> kzero - (STPair{}, EJust _ (EPair _ e1 EZero{})) -> - simplifyOneHotTerm (OneHotTerm t (SAPFst SAPHere) idx e1) kzero ktriv k - (STPair{}, EJust _ (EPair _ EZero{} e2)) -> - simplifyOneHotTerm (OneHotTerm t (SAPSnd SAPHere) idx e2) kzero ktriv k - - (STEither{}, ENothing _ _) -> kzero - (STEither{}, EJust _ (EInl _ _ e1)) -> - simplifyOneHotTerm (OneHotTerm t (SAPLeft SAPHere) idx e1) kzero ktriv k - (STEither{}, EJust _ (EInr _ _ e2)) -> - simplifyOneHotTerm (OneHotTerm t (SAPRight SAPHere) idx e2) kzero ktriv k - - (STMaybe{}, ENothing _ _) -> kzero - (STMaybe{}, EJust _ e1) -> - simplifyOneHotTerm (OneHotTerm t (SAPJust SAPHere) idx e1) kzero ktriv k - - (STArr{}, ENothing _ _) -> kzero - - (STScal STI32, _) -> kzero - (STScal STI64, _) -> kzero - (STScal STF32, EConst _ _ 0.0) -> kzero - (STScal STF64, EConst _ _ 0.0) -> kzero - (STScal STBool, _) -> kzero - - _ -> ktriv e - -simplifyOneHotTerm term _ _ k = k term - -concatOneHots :: STy a - -> SAcPrj p1 a b -> Ex env (AcIdx p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx p2 b) - -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx p12 a) -> r) -> r -concatOneHots t1 prj1 idx1 prj2 idx2 k = case (t1, prj1) of - (_, SAPHere) -> k prj2 idx2 - - (STPair a _, SAPFst prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPFst prj12) idx12 - (STPair _ b, SAPSnd prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPSnd prj12) idx12 - - (STEither a _, SAPLeft prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12 - (STEither _ b, SAPRight prj1') -> - concatOneHots b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12 - - (STMaybe a, SAPJust prj1') -> - concatOneHots a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12 - - (STArr n a, SAPArrIdx prj1' _) -> - concatOneHots a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 -> - k (SAPArrIdx prj12 n) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12) |
