diff options
-rw-r--r-- | src/AST.hs | 10 | ||||
-rw-r--r-- | src/AST/Count.hs | 2 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 2 | ||||
-rw-r--r-- | src/AST/Weaken.hs | 4 | ||||
-rw-r--r-- | src/CHAD.hs | 27 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 18 | ||||
-rw-r--r-- | src/Interpreter.hs | 7 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 1 | ||||
-rw-r--r-- | src/Language.hs | 4 | ||||
-rw-r--r-- | src/Language/AST.hs | 4 | ||||
-rw-r--r-- | src/Simplify.hs | 4 | ||||
-rw-r--r-- | test/Main.hs | 3 |
12 files changed, 48 insertions, 38 deletions
@@ -92,7 +92,7 @@ data Expr x env t where EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) - EIdx :: x t -> SNat n -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t + EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t @@ -194,7 +194,7 @@ typeOf = \case EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t - EIdx _ _ e _ | STArr _ t <- typeOf e -> t + EIdx _ e _ | STArr _ t <- typeOf e -> t EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) EOp _ op _ -> opt2 op @@ -267,7 +267,7 @@ subst' f w = \case EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) - EIdx x n e es -> EIdx x n (subst' f w e) (subst' f w es) + EIdx x e es -> EIdx x (subst' f w e) (subst' f w es) EShape x e -> EShape x (subst' f w e) EOp x op e -> EOp x op (subst' f w e) EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) @@ -339,5 +339,5 @@ ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) $ let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ - in EIdx ext n (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) - (EFst ext arg) + in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) + (EFst ext arg) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index dbec446..31720a5 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -123,7 +123,7 @@ occCountGeneral onehot unpush alter many = go WId EConst{} -> mempty EIdx0 _ e -> re e EIdx1 _ a b -> re a <> re b - EIdx _ _ a b -> re a <> re b + EIdx _ a b -> re a <> re b EShape _ e -> re e EOp _ _ e -> re e EWith a b -> re a <> re1 b diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index d811912..b50506a 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -159,7 +159,7 @@ ppExpr' d val = \case b' <- ppExpr' 9 val b return $ showParen (d > 8) $ a' . showString " .! " . b' - EIdx _ _ a b -> do + EIdx _ a b -> do a' <- ppExpr' 9 val a b' <- ppExpr' 10 val b return $ showParen (d > 8) $ diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 0a1e4ce..ecd7bc9 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -48,7 +48,7 @@ data env :> env' where -> Append pre (t : env) :> t : Append pre env' WSwap :: forall env as bs. SList (Const ()) as -> SList (Const ()) bs -> Append as (Append bs env) :> Append bs (Append as env) - WStack :: forall as bs env1 env2. SList (Const ()) as -> SList (Const ()) bs + WStack :: forall env1 env2 as bs. SList (Const ()) as -> SList (Const ()) bs -> as :> bs -> env1 :> env2 -> Append as env1 :> Append bs env2 deriving instance Show (env :> env') @@ -74,7 +74,7 @@ WSwap @env (as :: SList _ as) (bs :: SList _ bs) @> i = Right i' -> case splitIdx @env bs i' of Left j -> indexRaiseAbove @(Append as env) bs j Right j -> indexSinks bs (indexSinks as j) -WStack @as @bs @env1 @env2 as bs wlo whi @> i = +WStack @env1 @env2 as bs wlo whi @> i = case splitIdx @env1 as i of Left i' -> indexRaiseAbove @env2 bs (wlo @> i') Right i' -> indexSinks bs (whi @> i') diff --git a/src/CHAD.hs b/src/CHAD.hs index 4694ac4..e77dbe7 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -981,8 +981,8 @@ drev des = \case (#e0 :++: #ix :++: #sh :++: #she0 :++: #d1env))))) (EBuild ext ndim (EVar ext shty (IS IZ)) - (ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (IS IZ)) - (EVar ext shty IZ)) $ + (ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (IS IZ)) + (EVar ext shty IZ)) $ let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ in letBinds rebinds $ weakenExpr (autoWeak (#ix (shty `SCons` SNil) @@ -1004,11 +1004,11 @@ drev des = \case makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ -- the cotangent for this element - ELet ext (EIdx ext ndim (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ -- the tape for this element - ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) - (EVar ext shty (IS IZ))) $ + ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ in letBinds rebinds $ weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) @@ -1073,19 +1073,20 @@ drev des = \case (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ weakenExpr (WCopy (WSink .> WSink)) e2) - EIdx _ n e ei + EIdx _ e ei -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil - , STArr _ eltty <- typeOf e + , STArr n eltty <- typeOf e , Refl <- indexTupD1Id n -> - Ret (binds `BPush` (STArr n (d1 eltty), e1)) - (EIdx ext n (EVar ext (STArr n (d1 eltty)) IZ) - (weakenExpr WSink ei1)) + Ret (binds `BPush` (STArr n (d1 eltty), e1) + `BPush` (tTup (sreplicate n tIx), EShape ext (EVar ext (typeOf e1) IZ))) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS IZ)) + (weakenExpr (WSink .> WSink) ei1)) sub - (ELet ext (EBuild ext n (EShape ext (EVar ext (STArr n (d1 eltty)) (IS IZ))) + (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ)) (EVar ext (d2 eltty) (IS IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) EShape _ e -- Allowed to ignore e2 here because the output of EShape is discrete, diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index a93b8e6..f02b93e 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -171,8 +171,10 @@ dfwdDN = \case (EConst ext t x) EIdx0 _ e -> EIdx0 ext (dfwdDN e) EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b) - EIdx _ n a b - | Refl <- dnPreservesTupIx n -> EIdx ext n (dfwdDN a) (dfwdDN b) + EIdx _ a b + | STArr n _ <- typeOf a + , Refl <- dnPreservesTupIx n + -> EIdx ext (dfwdDN a) (dfwdDN b) EShape _ e | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e) EOp _ op e -> dop op (dfwdDN e) @@ -191,8 +193,8 @@ emap f arr = let STArr n t = typeOf arr in ELet ext arr $ EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ weakenExpr (WCopy (WSink .> WSink)) f ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) @@ -202,7 +204,7 @@ ezip a b = in ELet ext a $ ELet ext (weakenExpr WSink b) $ EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - (EIdx ext n (EVar ext (STArr n t2) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) + EPair ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + (EIdx ext (EVar ext (STArr n t2) (IS IZ)) + (EVar ext (tTup (sreplicate n tIx)) IZ)) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 7be1c4b..3fb5d7b 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -23,6 +23,7 @@ import Data.Char (isSpace) import Data.Kind (Type) import Data.Int (Int64) import Data.IORef +import GHC.Stack (HasCallStack) import System.IO.Unsafe (unsafePerformIO) import Debug.Trace @@ -33,7 +34,6 @@ import CHAD.Types import Data import Interpreter.Rep import Data.Bifunctor (bimap) -import GHC.Stack (HasCallStack) newtype AcM s a = AcM { unAcM :: IO a } @@ -95,7 +95,9 @@ interpret' env = \case EConst _ _ v -> return v EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) - EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) + EIdx _ a b + | STArr n _ <- typeOf a + -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e EOp _ op e -> interpretOp op <$> interpret' env e EWith e1 e2 -> do @@ -135,6 +137,7 @@ zeroD2 typ = case typ of STPair _ _ -> Left () STEither _ _ -> Left () STMaybe _ -> Nothing + STArr SZ t -> arrayUnit (zeroD2 t) STArr n _ -> emptyArray n STScal sty -> case sty of STI32 -> () diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index ed307c0..5c20183 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -25,6 +25,7 @@ type family RepAcSparse t where RepAcSparse (TPair a b) = IORef (RepAcSparse a, RepAcSparse b) RepAcSparse (TEither a b) = TypeError (Text "Non-sparse coproduct is not a monoid") RepAcSparse (TMaybe t) = IORef (Maybe (RepAcDense t)) -- allow the value to be dense, because the Maybe's zero can be used for the contents + -- TODO: an empty array is invalid for a zero-dimensional array, so zero-dimensional arrays don't actually have an O(1) zero. RepAcSparse (TArr n t) = IORef (Array n (RepAcSparse t)) -- empty array is zero RepAcSparse (TScal sty) = IORef (ScalRep sty) RepAcSparse (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators") diff --git a/src/Language.hs b/src/Language.hs index c2b844e..a025236 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -88,8 +88,8 @@ idx0 = NEIdx0 (.!) = NEIdx1 infixl 9 .! -(!) :: KnownNat n => NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t -(!) = NEIdx knownNat +(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t +(!) = NEIdx infixl 9 ! shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 0945dd9..409d24d 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -52,7 +52,7 @@ data NExpr env t where NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) - NEIdx :: SNat n -> NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t + NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) NEOp :: SOp a t -> NExpr env a -> NExpr env t @@ -131,7 +131,7 @@ fromNamedExpr val = \case NEConst t x -> EConst ext t x NEIdx0 e -> EIdx0 ext (go e) NEIdx1 a b -> EIdx1 ext (go a) (go b) - NEIdx n a b -> EIdx ext n (go a) (go b) + NEIdx a b -> EIdx ext (go a) (go b) NEShape e -> EShape ext (go e) NEOp op e -> EOp ext op (go e) diff --git a/src/Simplify.hs b/src/Simplify.hs index 3f4c8e3..5829a8b 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -94,7 +94,7 @@ simplify' = \case EConst _ t v -> EConst ext t v EIdx0 _ e -> EIdx0 ext (simplify' e) EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b) - EIdx _ n a b -> EIdx ext n (simplify' a) (simplify' b) + EIdx _ a b -> EIdx ext (simplify' a) (simplify' b) EShape _ e -> EShape ext (simplify' e) EOp _ op e -> EOp ext op (simplify' e) EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2) @@ -136,7 +136,7 @@ hasAdds = \case EConst _ _ _ -> False EIdx0 _ e -> hasAdds e EIdx1 _ a b -> hasAdds a || hasAdds b - EIdx _ _ a b -> hasAdds a || hasAdds b + EIdx _ a b -> hasAdds a || hasAdds b EShape _ e -> hasAdds e EOp _ _ e -> hasAdds e EWith a b -> hasAdds a || hasAdds b diff --git a/test/Main.hs b/test/Main.hs index f779352..a3fa484 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -241,6 +241,9 @@ tests = checkParallel $ Group "AD" let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $ fst_ #q * #x + snd_ #q * fst_ #p) + ,("build0 const", adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $ + idx0 $ build SZ nil $ #idx :-> const_ 0.0) + ,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $ idx0 $ build SZ (shape #x) $ #idx :-> #x ! #idx) |