summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-11-07 23:58:03 +0100
committerTom Smeding <tom@tomsmeding.com>2024-11-07 23:58:03 +0100
commit58d4d0b47f5e609e21132f48b727de37d06b6777 (patch)
tree2339f67037ab37d26d5f3a50e30b005cc0bb7015
parent92ddb2263ae495c229badcc209c76a1252bd2752 (diff)
Remove build1
-rw-r--r--src/AST.hs3
-rw-r--r--src/AST/Count.hs1
-rw-r--r--src/AST/Pretty.hs7
-rw-r--r--src/CHAD.hs131
-rw-r--r--src/ForwardAD/DualNumbers.hs1
-rw-r--r--src/Interpreter.hs4
-rw-r--r--src/Language.hs23
-rw-r--r--src/Language/AST.hs21
-rw-r--r--src/Simplify.hs2
9 files changed, 35 insertions, 158 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 60fc5ad..f603443 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -82,7 +82,6 @@ data Expr x env t where
-- array operations
EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
- EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
@@ -190,7 +189,6 @@ typeOf = \case
EMaybe _ e _ _ -> typeOf e
EConstArr _ n t _ -> STArr n (STScal t)
- EBuild1 _ _ e -> STArr (SS SZ) (typeOf e)
EBuild _ n _ e -> STArr n (typeOf e)
EFold1Inner _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
@@ -265,7 +263,6 @@ subst' f w = \case
EJust x e -> EJust x (subst' f w e)
EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
EConstArr x n t a -> EConstArr x n t a
- EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EFold1Inner x a b c -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
ESum1Inner x e -> ESum1Inner x (subst' f w e)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index f3e3d74..71b38b1 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -114,7 +114,6 @@ occCountGeneral onehot unpush alter many = go WId
EJust _ e -> re e
EMaybe _ a b e -> re a <> re1 b <> re e
EConstArr{} -> mempty
- EBuild1 _ a b -> re a <> many (re1 b)
EBuild _ _ a b -> re a <> many (re1 b)
EFold1Inner _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
ESum1Inner _ e -> re e
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 677c767..76424fe 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -110,13 +110,6 @@ ppExpr' d val = \case
EConstArr _ _ ty v
| Dict <- scalRepIsShow ty -> return $ showsPrec d v
- EBuild1 _ a b -> do
- a' <- ppExpr' 11 val a
- name <- genNameIfUsedIn' "i" (STScal STI64) IZ b
- b' <- ppExpr' 0 (Const name `SCons` val) b
- return $ showParen (d > 10) $
- showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")"
-
EBuild _ n a b -> do
a' <- ppExpr' 11 val a
name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b
diff --git a/src/CHAD.hs b/src/CHAD.hs
index ffbdcac..8080ec0 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -192,57 +192,6 @@ reconstructBindings binds tape =
,sreverse (stapeUnfoldings binds))
---------------------------------- VECTORISATION --------------------------------
--- Currently only used in D[build1], should be removed.
-
-{-
-type family Vectorise n list where
- Vectorise _ '[] = '[]
- Vectorise n (t : ts) = TArr n t : Vectorise n ts
-
-vectoriseEnv :: SNat n -> SList STy env -> SList STy (Vectorise n env)
-vectoriseEnv _ SNil = SNil
-vectoriseEnv n (t `SCons` env) = STArr n t `SCons` vectoriseEnv n env
-
-vectoriseIdx :: Idx binds t -> Idx (Vectorise n binds) (TArr n t)
-vectoriseIdx IZ = IZ
-vectoriseIdx (IS i) = IS (vectoriseIdx i)
-
-vectoriseExpr :: forall prefix binds env t f.
- SList f prefix
- -> SList STy binds
- -> SList f env
- -> Ex (Append prefix (Append binds env)) t
- -> Ex (TIx : Append prefix (Append (Vectorise (S Z) binds) env)) t
-vectoriseExpr prefix binds env =
- let wTarget :: Layout True ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e
- -> e :> TIx : Append prefix (Append (Vectorise (S Z) binds) env)
- wTarget layout =
- autoWeak (#ix (auto1 @TIx) &. #pre prefix &. #vbinds (vectoriseEnv (SS SZ) binds) &. #env env)
- layout
- (#ix :++: #pre :++: #vbinds :++: #env)
- in
- subst $ \_ t i ->
- case splitIdx @(Append binds env) prefix i of
- Left iPre -> EVar ext t (wTarget #pre @> iPre)
- Right i' ->
- case splitIdx @env binds i' of
- Left iBinds ->
- EIdx0 ext $
- EIdx1 ext (EVar ext (STArr (SS SZ) t) (wTarget #vbinds @> vectoriseIdx iBinds))
- (EVar ext tIx IZ)
- Right iEnv -> EVar ext t (wTarget #env @> iEnv)
-
-vectorise1Binds :: forall env binds. SList STy env -> Idx env TIx -> Bindings Ex env binds -> Bindings Ex env (Vectorise (S Z) binds)
-vectorise1Binds _ _ BTop = BTop
-vectorise1Binds env n (bs `BPush` (t, e)) =
- let bs' = vectorise1Binds env n bs
- e' = EBuild1 ext (EVar ext tIx (sinkWithBindings bs' @> n))
- (vectoriseExpr SNil (bindingsBinds bs) env e)
- in bs' `BPush` (STArr (SS SZ) t, e')
--}
-
-
---------------------- ENVIRONMENT DESCRIPTION AND STORAGE ---------------------
type Storage :: Symbol -> Type
@@ -939,86 +888,6 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
- EBuild1{} -> error "CHAD of EBuild1: Please use EBuild instead"
- {-
- -- TODO: either remove EBuilds1 entirely or rewrite it to work with an array of tapes instead of a vectorised tape
- EBuild1 _ ne (orige :: Ex _ eltty)
- | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne -- allowed to ignore ne2 here because ne has a discrete result
- , let eltty = typeOf orige ->
- deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
- let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
- subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
- accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro ->
- case drev (prodes `DPush` (tIx, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 ->
- case assertSubenvEmpty sub of { Refl ->
- let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in
- Ret (bconcat (ne0 `BPush` (tIx, ne1))
- (fst (weakenBindings weakenExpr (WCopy (wSinksAnd (bindingsBinds ne0) (wUndoSubenv subD1eUsed))) ve0)))
- (EBuild1 ext
- (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0)
- &. #binds (tIx `SCons` bindingsBinds ne0)
- &. #tl (sD1eEnv des))
- #binds
- ((#ve0 :++: #binds) :++: #tl))
- (EVar ext tIx IZ))
- (subst (\_ t i -> case splitIdx @(TIx : D1E env') (bindingsBinds e0) i of
- Left ibind ->
- let ibind' =
- autoWeak (#ix (auto1 @TIx)
- &. #ve0 (bindingsBinds ve0)
- &. #binds (tIx `SCons` bindingsBinds ne0)
- &. #tl (sD1eEnv des))
- #ve0
- (#ix :++: (#ve0 :++: #binds) :++: #tl)
- @> vectoriseIdx ibind
- in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind')
- (EVar ext tIx IZ))
- Right IZ -> EVar ext tIx IZ -- build lambda index argument
- Right (IS ienv) -> EVar ext t (IS (wSinksAnd (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) (wUndoSubenv subD1eUsed) @> ienv)))
- e1))
- (subenvCompose subMergeUsed proSub)
- (ELet ext
- (uninvertTup (d2e envPro) (STArr (SS SZ) STNil) $
- makeAccumulators @_ @_ @(TArr (S Z) TNil) envPro $
- EBuild1 ext
- (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0)
- &. #pro (d2ace envPro)
- &. #d (auto1 @(D2 t))
- &. #binds (tIx `SCons` bindingsBinds ne0)
- &. #tl (d2ace (select SAccum des)))
- #binds
- (#pro :++: #d :++: (#ve0 :++: #binds) :++: #tl))
- (EVar ext tIx IZ))
- (ELet ext (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (d2 eltty))
- (IS (wSinks @(TArr (S Z) (D2 eltty) : Append (Append (Vectorise (S Z) e_binds) (TIx : ne_binds)) (D2AcE (Select env sto "accum")))
- (d2ace envPro)
- @> IZ)))
- (EVar ext tIx IZ))) $
- weakenExpr (autoWeak (#i (auto1 @TIx)
- &. #dpro (d2ace envPro)
- &. #d (d2 eltty `SCons` SNil)
- &. #darr (STArr (SS SZ) (d2 eltty) `SCons` SNil)
- &. #n (auto1 @TIx)
- &. #vbinds (bindingsBinds ve0)
- &. #ne0 (bindingsBinds ne0)
- &. #tl (d2ace (select SAccum des)))
- (#i :++: (#dpro :++: #d) :++: #vbinds :++: #tl)
- (#d :++: #i :++: #dpro :++: #darr :++: (#vbinds :++: #n :++: #ne0) :++: #tl)) $
- vectoriseExpr (sappend (d2ace envPro) (d2 eltty `SCons` SNil)) (bindingsBinds e0) (d2ace (select SAccum des)) $
- weakenExpr (autoWeak (#dpro (d2ace envPro)
- &. #d (d2 eltty `SCons` SNil)
- &. #binds (bindingsBinds e0)
- &. #tl (d2ace (select SAccum des)))
- (#dpro :++: #d :++: #binds :++: #tl)
- ((#dpro :++: #d) :++: #binds :++: #tl)) $
- weakenExpr (wCopies (d2ace envPro) (WCopy @(D2 eltty) (wCopies (bindingsBinds e0) (wUndoSubenv subAccumUsed)))) $
- weakenExpr (wPro (bindingsBinds e0)) $
- e2)) $
- ELet ext (ENil ext) $
- ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ)))
- }}
- -}
-
-- TODO: merge the e0 and e1 builds in a single build just like they are merged into a single case in D[case]0, then it can really store only the parts that need to be preserved until D[build]2
EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty)
| Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 8b4acb3..8e84378 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -152,7 +152,6 @@ dfwdDN = \case
(emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0))
(EConstArr ext n t x))
(EConstArr ext n t x)
- EBuild1 _ a b -> EBuild1 ext (dfwdDN a) (dfwdDN b)
EBuild _ n a b
| Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b)
EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c)
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 36543e9..abc9800 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -90,10 +90,6 @@ interpret'Rec env = \case
EJust _ e -> Just <$> interpret' env e
EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e
EConstArr _ _ _ v -> return v
- EBuild1 _ a b -> do
- n <- fromIntegral @Int64 @Int <$> interpret' env a
- arrayGenerateLinM (ShNil `ShCons` n)
- (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
EBuild _ dim a b -> do
sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b)
diff --git a/src/Language.hs b/src/Language.hs
index a1b3d8b..e8dc89f 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeApplications #-}
module Language (
fromNamed,
NExpr,
@@ -65,7 +66,18 @@ constArr_ x =
Dict -> NEConstArr knownNat ty x
build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t)
-build1 a (v :-> b) = NEBuild1 a v b
+build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b))
+
+build2 :: NExpr env TIx -> NExpr env TIx
+ -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t)
+ -> NExpr env (TArr (S (S Z)) t)
+build2 a1 a2 (v1 :-> v2 :-> b) =
+ NEBuild (SS (SS SZ))
+ (pair (pair nil a1) a2)
+ #idx
+ (let_ v1 (snd_ (fst_ #idx)) $
+ let_ v2 (NEDrop SZ (snd_ #idx)) $
+ NEDrop (SS (SS SZ)) b)
build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t)
build n a (v :-> b) = NEBuild n a v b
@@ -131,9 +143,6 @@ infix 4 .>=
not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool)
not_ = oper ONot
--- | The "_" variables in scope are unusable and should be ignored. With a
--- weakening function on NExprs they could be hidden.
---
--- The first alternative is the True case; the second is the False case.
-if_ :: NExpr env (TScal TBool) -> NExpr ('("_", TNil) : env) t -> NExpr ('("_", TNil) : env) t -> NExpr env t
-if_ e a b = case_ (oper OIf e) (#_ :-> a) (#_ :-> b)
+-- | The first alternative is the True case; the second is the False case.
+if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t
+if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b)
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 3b04bec..f5203e9 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -30,6 +30,9 @@ data NExpr env t where
NEVar :: Lookup name env ~ t => Var name t -> NExpr env t
NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t
+ -- environment management
+ NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t
+
-- base types
NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b)
NEFst :: NExpr env (TPair a b) -> NExpr env a
@@ -41,7 +44,6 @@ data NExpr env t where
-- array operations
NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
- NEBuild1 :: NExpr env TIx -> Var name TIx -> NExpr ('(name, TIx) : env) t -> NExpr env (TArr (S Z) t)
NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
@@ -68,6 +70,10 @@ type family Lookup name env where
Lookup name ('(name, t) : env) = t
Lookup name (_ : env) = Lookup name env
+type family DropNth i env where
+ DropNth Z (_ : env) = env
+ DropNth (S i) (p : env) = p : DropNth i env
+
data Var name t = Var (SSymbol name) (STy t)
deriving (Show)
@@ -135,6 +141,8 @@ fromNamedExpr val = \case
\expression to De Bruijn expression"
NELet n a b -> ELet ext (go a) (lambda val n b)
+ NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e)
+
NEPair a b -> EPair ext (go a) (go b)
NEFst e -> EFst ext (go e)
NESnd e -> ESnd ext (go e)
@@ -144,7 +152,6 @@ fromNamedExpr val = \case
NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b)
NEConstArr n t x -> EConstArr ext n t x
- NEBuild1 a n b -> EBuild1 ext (go a) (lambda val n b)
NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c)
NESum1Inner e -> ESum1Inner ext (go e)
@@ -185,3 +192,13 @@ fromNamedExpr val = \case
injectWrapLet e (arg `SCons` args) =
injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e)
args
+
+dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env)
+dropNth SZ (val `NPush` _) = val
+dropNth (SS i) (val `NPush` p) = dropNth i val `NPush` p
+dropNth _ NTop = error "DropNth: index out of range"
+
+dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
+dropNthW SZ (_ `NPush` _) = WSink
+dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val)
+dropNthW _ NTop = error "DropNth: index out of range"
diff --git a/src/Simplify.hs b/src/Simplify.hs
index f8b4b63..66a4004 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -110,7 +110,6 @@ simplify' = \case
EJust _ e -> EJust ext <$> simplify' e
EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e
EConstArr _ n t v -> pure $ EConstArr ext n t v
- EBuild1 _ a b -> EBuild1 ext <$> simplify' a <*> simplify' b
EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b
EFold1Inner _ a b c -> EFold1Inner ext <$> simplify' a <*> simplify' b <*> simplify' c
ESum1Inner _ e -> ESum1Inner ext <$> simplify' e
@@ -156,7 +155,6 @@ hasAdds = \case
EJust _ e -> hasAdds e
EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e
EConstArr _ _ _ _ -> False
- EBuild1 _ a b -> hasAdds a || hasAdds b
EBuild _ _ a b -> hasAdds a || hasAdds b
EFold1Inner _ a b c -> hasAdds a || hasAdds b || hasAdds c
ESum1Inner _ e -> hasAdds e