summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/AST.hs10
-rw-r--r--src/AST/Count.hs2
-rw-r--r--src/AST/Pretty.hs2
-rw-r--r--src/AST/Weaken.hs4
-rw-r--r--src/CHAD.hs27
-rw-r--r--src/ForwardAD/DualNumbers.hs18
-rw-r--r--src/Interpreter.hs7
-rw-r--r--src/Interpreter/Rep.hs1
-rw-r--r--src/Language.hs4
-rw-r--r--src/Language/AST.hs4
-rw-r--r--src/Simplify.hs4
-rw-r--r--test/Main.hs3
12 files changed, 48 insertions, 38 deletions
diff --git a/src/AST.hs b/src/AST.hs
index e2702ab..af137b2 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -92,7 +92,7 @@ data Expr x env t where
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
- EIdx :: x t -> SNat n -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
+ EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
@@ -194,7 +194,7 @@ typeOf = \case
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
- EIdx _ _ e _ | STArr _ t <- typeOf e -> t
+ EIdx _ e _ | STArr _ t <- typeOf e -> t
EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
EOp _ op _ -> opt2 op
@@ -267,7 +267,7 @@ subst' f w = \case
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
- EIdx x n e es -> EIdx x n (subst' f w e) (subst' f w es)
+ EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
EShape x e -> EShape x (subst' f w e)
EOp x op e -> EOp x op (subst' f w e)
EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
@@ -339,5 +339,5 @@ ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx :
ebuildUp1 n sh size f =
EBuild ext (SS n) (EPair ext sh size) $
let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
- in EIdx ext n (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f))
- (EFst ext arg)
+ in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f))
+ (EFst ext arg)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index dbec446..31720a5 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -123,7 +123,7 @@ occCountGeneral onehot unpush alter many = go WId
EConst{} -> mempty
EIdx0 _ e -> re e
EIdx1 _ a b -> re a <> re b
- EIdx _ _ a b -> re a <> re b
+ EIdx _ a b -> re a <> re b
EShape _ e -> re e
EOp _ _ e -> re e
EWith a b -> re a <> re1 b
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index d811912..b50506a 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -159,7 +159,7 @@ ppExpr' d val = \case
b' <- ppExpr' 9 val b
return $ showParen (d > 8) $ a' . showString " .! " . b'
- EIdx _ _ a b -> do
+ EIdx _ a b -> do
a' <- ppExpr' 9 val a
b' <- ppExpr' 10 val b
return $ showParen (d > 8) $
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index 0a1e4ce..ecd7bc9 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -48,7 +48,7 @@ data env :> env' where
-> Append pre (t : env) :> t : Append pre env'
WSwap :: forall env as bs. SList (Const ()) as -> SList (Const ()) bs
-> Append as (Append bs env) :> Append bs (Append as env)
- WStack :: forall as bs env1 env2. SList (Const ()) as -> SList (Const ()) bs
+ WStack :: forall env1 env2 as bs. SList (Const ()) as -> SList (Const ()) bs
-> as :> bs -> env1 :> env2
-> Append as env1 :> Append bs env2
deriving instance Show (env :> env')
@@ -74,7 +74,7 @@ WSwap @env (as :: SList _ as) (bs :: SList _ bs) @> i =
Right i' -> case splitIdx @env bs i' of
Left j -> indexRaiseAbove @(Append as env) bs j
Right j -> indexSinks bs (indexSinks as j)
-WStack @as @bs @env1 @env2 as bs wlo whi @> i =
+WStack @env1 @env2 as bs wlo whi @> i =
case splitIdx @env1 as i of
Left i' -> indexRaiseAbove @env2 bs (wlo @> i')
Right i' -> indexSinks bs (whi @> i')
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 4694ac4..e77dbe7 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -981,8 +981,8 @@ drev des = \case
(#e0 :++: #ix :++: #sh :++: #she0 :++: #d1env)))))
(EBuild ext ndim
(EVar ext shty (IS IZ))
- (ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (IS IZ))
- (EVar ext shty IZ)) $
+ (ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (IS IZ))
+ (EVar ext shty IZ)) $
let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ
in letBinds rebinds $
weakenExpr (autoWeak (#ix (shty `SCons` SNil)
@@ -1004,11 +1004,11 @@ drev des = \case
makeAccumulators @_ @_ @(TArr ndim TNil) envPro $
EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
-- the cotangent for this element
- ELet ext (EIdx ext ndim (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
- (EVar ext shty IZ)) $
+ ELet ext (EIdx ext (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
-- the tape for this element
- ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
- (EVar ext shty (IS IZ))) $
+ ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty (IS IZ))) $
let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ
in letBinds rebinds $
weakenExpr (autoWeak (#d (auto1 @(D2 eltty))
@@ -1073,19 +1073,20 @@ drev des = \case
(EVar ext (STArr n (d2 eltty)) (IS IZ))) $
weakenExpr (WCopy (WSink .> WSink)) e2)
- EIdx _ n e ei
+ EIdx _ e ei
-- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
| Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
<- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
- , STArr _ eltty <- typeOf e
+ , STArr n eltty <- typeOf e
, Refl <- indexTupD1Id n ->
- Ret (binds `BPush` (STArr n (d1 eltty), e1))
- (EIdx ext n (EVar ext (STArr n (d1 eltty)) IZ)
- (weakenExpr WSink ei1))
+ Ret (binds `BPush` (STArr n (d1 eltty), e1)
+ `BPush` (tTup (sreplicate n tIx), EShape ext (EVar ext (typeOf e1) IZ)))
+ (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS IZ))
+ (weakenExpr (WSink .> WSink) ei1))
sub
- (ELet ext (EBuild ext n (EShape ext (EVar ext (STArr n (d1 eltty)) (IS IZ)))
+ (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ))
(EVar ext (d2 eltty) (IS IZ))) $
- weakenExpr (WCopy (WSink .> WSink)) e2)
+ weakenExpr (WCopy (WSink .> WSink .> WSink)) e2)
EShape _ e
-- Allowed to ignore e2 here because the output of EShape is discrete,
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index a93b8e6..f02b93e 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -171,8 +171,10 @@ dfwdDN = \case
(EConst ext t x)
EIdx0 _ e -> EIdx0 ext (dfwdDN e)
EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b)
- EIdx _ n a b
- | Refl <- dnPreservesTupIx n -> EIdx ext n (dfwdDN a) (dfwdDN b)
+ EIdx _ a b
+ | STArr n _ <- typeOf a
+ , Refl <- dnPreservesTupIx n
+ -> EIdx ext (dfwdDN a) (dfwdDN b)
EShape _ e
| Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) -> EShape ext (dfwdDN e)
EOp _ op e -> dop op (dfwdDN e)
@@ -191,8 +193,8 @@ emap f arr =
let STArr n t = typeOf arr
in ELet ext arr $
EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
- ELet ext (EIdx ext n (EVar ext (STArr n t) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ)) $
weakenExpr (WCopy (WSink .> WSink)) f
ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
@@ -202,7 +204,7 @@ ezip a b =
in ELet ext a $
ELet ext (weakenExpr WSink b) $
EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
- EPair ext (EIdx ext n (EVar ext (STArr n t1) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
- (EIdx ext n (EVar ext (STArr n t2) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
+ EPair ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (EIdx ext (EVar ext (STArr n t2) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 7be1c4b..3fb5d7b 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -23,6 +23,7 @@ import Data.Char (isSpace)
import Data.Kind (Type)
import Data.Int (Int64)
import Data.IORef
+import GHC.Stack (HasCallStack)
import System.IO.Unsafe (unsafePerformIO)
import Debug.Trace
@@ -33,7 +34,6 @@ import CHAD.Types
import Data
import Interpreter.Rep
import Data.Bifunctor (bimap)
-import GHC.Stack (HasCallStack)
newtype AcM s a = AcM { unAcM :: IO a }
@@ -95,7 +95,9 @@ interpret' env = \case
EConst _ _ v -> return v
EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
- EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
+ EIdx _ a b
+ | STArr n _ <- typeOf a
+ -> arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b)
EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e
EOp _ op e -> interpretOp op <$> interpret' env e
EWith e1 e2 -> do
@@ -135,6 +137,7 @@ zeroD2 typ = case typ of
STPair _ _ -> Left ()
STEither _ _ -> Left ()
STMaybe _ -> Nothing
+ STArr SZ t -> arrayUnit (zeroD2 t)
STArr n _ -> emptyArray n
STScal sty -> case sty of
STI32 -> ()
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
index ed307c0..5c20183 100644
--- a/src/Interpreter/Rep.hs
+++ b/src/Interpreter/Rep.hs
@@ -25,6 +25,7 @@ type family RepAcSparse t where
RepAcSparse (TPair a b) = IORef (RepAcSparse a, RepAcSparse b)
RepAcSparse (TEither a b) = TypeError (Text "Non-sparse coproduct is not a monoid")
RepAcSparse (TMaybe t) = IORef (Maybe (RepAcDense t)) -- allow the value to be dense, because the Maybe's zero can be used for the contents
+ -- TODO: an empty array is invalid for a zero-dimensional array, so zero-dimensional arrays don't actually have an O(1) zero.
RepAcSparse (TArr n t) = IORef (Array n (RepAcSparse t)) -- empty array is zero
RepAcSparse (TScal sty) = IORef (ScalRep sty)
RepAcSparse (TAccum t) = TypeError (Text "RepAcSparse: Nested accumulators")
diff --git a/src/Language.hs b/src/Language.hs
index c2b844e..a025236 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -88,8 +88,8 @@ idx0 = NEIdx0
(.!) = NEIdx1
infixl 9 .!
-(!) :: KnownNat n => NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
-(!) = NEIdx knownNat
+(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
+(!) = NEIdx
infixl 9 !
shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 0945dd9..409d24d 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -52,7 +52,7 @@ data NExpr env t where
NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t
NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
- NEIdx :: SNat n -> NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
+ NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
NEOp :: SOp a t -> NExpr env a -> NExpr env t
@@ -131,7 +131,7 @@ fromNamedExpr val = \case
NEConst t x -> EConst ext t x
NEIdx0 e -> EIdx0 ext (go e)
NEIdx1 a b -> EIdx1 ext (go a) (go b)
- NEIdx n a b -> EIdx ext n (go a) (go b)
+ NEIdx a b -> EIdx ext (go a) (go b)
NEShape e -> EShape ext (go e)
NEOp op e -> EOp ext op (go e)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 3f4c8e3..5829a8b 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -94,7 +94,7 @@ simplify' = \case
EConst _ t v -> EConst ext t v
EIdx0 _ e -> EIdx0 ext (simplify' e)
EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b)
- EIdx _ n a b -> EIdx ext n (simplify' a) (simplify' b)
+ EIdx _ a b -> EIdx ext (simplify' a) (simplify' b)
EShape _ e -> EShape ext (simplify' e)
EOp _ op e -> EOp ext op (simplify' e)
EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2)
@@ -136,7 +136,7 @@ hasAdds = \case
EConst _ _ _ -> False
EIdx0 _ e -> hasAdds e
EIdx1 _ a b -> hasAdds a || hasAdds b
- EIdx _ _ a b -> hasAdds a || hasAdds b
+ EIdx _ a b -> hasAdds a || hasAdds b
EShape _ e -> hasAdds e
EOp _ _ e -> hasAdds e
EWith a b -> hasAdds a || hasAdds b
diff --git a/test/Main.hs b/test/Main.hs
index f779352..a3fa484 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -241,6 +241,9 @@ tests = checkParallel $ Group "AD"
let_ #q (pair (snd_ #p * fst_ #p + #y) #x) $
fst_ #q * #x + snd_ #q * fst_ #p)
+ ,("build0 const", adTest $ fromNamed $ lambda @(TScal TF64) #x $ body $
+ idx0 $ build SZ nil $ #idx :-> const_ 0.0)
+
,("build0", adTest $ fromNamed $ lambda @(TArr N0 _) #x $ body $
idx0 $
build SZ (shape #x) $ #idx :-> #x ! #idx)