diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-07 23:58:03 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-07 23:58:03 +0100 |
commit | 58d4d0b47f5e609e21132f48b727de37d06b6777 (patch) | |
tree | 2339f67037ab37d26d5f3a50e30b005cc0bb7015 | |
parent | 92ddb2263ae495c229badcc209c76a1252bd2752 (diff) |
Remove build1
-rw-r--r-- | src/AST.hs | 3 | ||||
-rw-r--r-- | src/AST/Count.hs | 1 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 7 | ||||
-rw-r--r-- | src/CHAD.hs | 131 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 1 | ||||
-rw-r--r-- | src/Interpreter.hs | 4 | ||||
-rw-r--r-- | src/Language.hs | 23 | ||||
-rw-r--r-- | src/Language/AST.hs | 21 | ||||
-rw-r--r-- | src/Simplify.hs | 2 |
9 files changed, 35 insertions, 158 deletions
@@ -82,7 +82,6 @@ data Expr x env t where -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) - EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) @@ -190,7 +189,6 @@ typeOf = \case EMaybe _ e _ _ -> typeOf e EConstArr _ n t _ -> STArr n (STScal t) - EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) EBuild _ n _ e -> STArr n (typeOf e) EFold1Inner _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t @@ -265,7 +263,6 @@ subst' f w = \case EJust x e -> EJust x (subst' f w e) EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) EConstArr x n t a -> EConstArr x n t a - EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) EFold1Inner x a b c -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) ESum1Inner x e -> ESum1Inner x (subst' f w e) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index f3e3d74..71b38b1 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -114,7 +114,6 @@ occCountGeneral onehot unpush alter many = go WId EJust _ e -> re e EMaybe _ a b e -> re a <> re1 b <> re e EConstArr{} -> mempty - EBuild1 _ a b -> re a <> many (re1 b) EBuild _ _ a b -> re a <> many (re1 b) EFold1Inner _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c ESum1Inner _ e -> re e diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 677c767..76424fe 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -110,13 +110,6 @@ ppExpr' d val = \case EConstArr _ _ ty v | Dict <- scalRepIsShow ty -> return $ showsPrec d v - EBuild1 _ a b -> do - a' <- ppExpr' 11 val a - name <- genNameIfUsedIn' "i" (STScal STI64) IZ b - b' <- ppExpr' 0 (Const name `SCons` val) b - return $ showParen (d > 10) $ - showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")" - EBuild _ n a b -> do a' <- ppExpr' 11 val a name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b diff --git a/src/CHAD.hs b/src/CHAD.hs index ffbdcac..8080ec0 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -192,57 +192,6 @@ reconstructBindings binds tape = ,sreverse (stapeUnfoldings binds)) ---------------------------------- VECTORISATION -------------------------------- --- Currently only used in D[build1], should be removed. - -{- -type family Vectorise n list where - Vectorise _ '[] = '[] - Vectorise n (t : ts) = TArr n t : Vectorise n ts - -vectoriseEnv :: SNat n -> SList STy env -> SList STy (Vectorise n env) -vectoriseEnv _ SNil = SNil -vectoriseEnv n (t `SCons` env) = STArr n t `SCons` vectoriseEnv n env - -vectoriseIdx :: Idx binds t -> Idx (Vectorise n binds) (TArr n t) -vectoriseIdx IZ = IZ -vectoriseIdx (IS i) = IS (vectoriseIdx i) - -vectoriseExpr :: forall prefix binds env t f. - SList f prefix - -> SList STy binds - -> SList f env - -> Ex (Append prefix (Append binds env)) t - -> Ex (TIx : Append prefix (Append (Vectorise (S Z) binds) env)) t -vectoriseExpr prefix binds env = - let wTarget :: Layout True ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e - -> e :> TIx : Append prefix (Append (Vectorise (S Z) binds) env) - wTarget layout = - autoWeak (#ix (auto1 @TIx) &. #pre prefix &. #vbinds (vectoriseEnv (SS SZ) binds) &. #env env) - layout - (#ix :++: #pre :++: #vbinds :++: #env) - in - subst $ \_ t i -> - case splitIdx @(Append binds env) prefix i of - Left iPre -> EVar ext t (wTarget #pre @> iPre) - Right i' -> - case splitIdx @env binds i' of - Left iBinds -> - EIdx0 ext $ - EIdx1 ext (EVar ext (STArr (SS SZ) t) (wTarget #vbinds @> vectoriseIdx iBinds)) - (EVar ext tIx IZ) - Right iEnv -> EVar ext t (wTarget #env @> iEnv) - -vectorise1Binds :: forall env binds. SList STy env -> Idx env TIx -> Bindings Ex env binds -> Bindings Ex env (Vectorise (S Z) binds) -vectorise1Binds _ _ BTop = BTop -vectorise1Binds env n (bs `BPush` (t, e)) = - let bs' = vectorise1Binds env n bs - e' = EBuild1 ext (EVar ext tIx (sinkWithBindings bs' @> n)) - (vectoriseExpr SNil (bindingsBinds bs) env e) - in bs' `BPush` (STArr (SS SZ) t, e') --} - - ---------------------- ENVIRONMENT DESCRIPTION AND STORAGE --------------------- type Storage :: Symbol -> Type @@ -939,86 +888,6 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) - EBuild1{} -> error "CHAD of EBuild1: Please use EBuild instead" - {- - -- TODO: either remove EBuilds1 entirely or rewrite it to work with an array of tapes instead of a vectorised tape - EBuild1 _ ne (orige :: Ex _ eltty) - | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne -- allowed to ignore ne2 here because ne has a discrete result - , let eltty = typeOf orige -> - deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> - let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in - subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> - case drev (prodes `DPush` (tIx, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 -> - case assertSubenvEmpty sub of { Refl -> - let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in - Ret (bconcat (ne0 `BPush` (tIx, ne1)) - (fst (weakenBindings weakenExpr (WCopy (wSinksAnd (bindingsBinds ne0) (wUndoSubenv subD1eUsed))) ve0))) - (EBuild1 ext - (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0) - &. #binds (tIx `SCons` bindingsBinds ne0) - &. #tl (sD1eEnv des)) - #binds - ((#ve0 :++: #binds) :++: #tl)) - (EVar ext tIx IZ)) - (subst (\_ t i -> case splitIdx @(TIx : D1E env') (bindingsBinds e0) i of - Left ibind -> - let ibind' = - autoWeak (#ix (auto1 @TIx) - &. #ve0 (bindingsBinds ve0) - &. #binds (tIx `SCons` bindingsBinds ne0) - &. #tl (sD1eEnv des)) - #ve0 - (#ix :++: (#ve0 :++: #binds) :++: #tl) - @> vectoriseIdx ibind - in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind') - (EVar ext tIx IZ)) - Right IZ -> EVar ext tIx IZ -- build lambda index argument - Right (IS ienv) -> EVar ext t (IS (wSinksAnd (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) (wUndoSubenv subD1eUsed) @> ienv))) - e1)) - (subenvCompose subMergeUsed proSub) - (ELet ext - (uninvertTup (d2e envPro) (STArr (SS SZ) STNil) $ - makeAccumulators @_ @_ @(TArr (S Z) TNil) envPro $ - EBuild1 ext - (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0) - &. #pro (d2ace envPro) - &. #d (auto1 @(D2 t)) - &. #binds (tIx `SCons` bindingsBinds ne0) - &. #tl (d2ace (select SAccum des))) - #binds - (#pro :++: #d :++: (#ve0 :++: #binds) :++: #tl)) - (EVar ext tIx IZ)) - (ELet ext (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (d2 eltty)) - (IS (wSinks @(TArr (S Z) (D2 eltty) : Append (Append (Vectorise (S Z) e_binds) (TIx : ne_binds)) (D2AcE (Select env sto "accum"))) - (d2ace envPro) - @> IZ))) - (EVar ext tIx IZ))) $ - weakenExpr (autoWeak (#i (auto1 @TIx) - &. #dpro (d2ace envPro) - &. #d (d2 eltty `SCons` SNil) - &. #darr (STArr (SS SZ) (d2 eltty) `SCons` SNil) - &. #n (auto1 @TIx) - &. #vbinds (bindingsBinds ve0) - &. #ne0 (bindingsBinds ne0) - &. #tl (d2ace (select SAccum des))) - (#i :++: (#dpro :++: #d) :++: #vbinds :++: #tl) - (#d :++: #i :++: #dpro :++: #darr :++: (#vbinds :++: #n :++: #ne0) :++: #tl)) $ - vectoriseExpr (sappend (d2ace envPro) (d2 eltty `SCons` SNil)) (bindingsBinds e0) (d2ace (select SAccum des)) $ - weakenExpr (autoWeak (#dpro (d2ace envPro) - &. #d (d2 eltty `SCons` SNil) - &. #binds (bindingsBinds e0) - &. #tl (d2ace (select SAccum des))) - (#dpro :++: #d :++: #binds :++: #tl) - ((#dpro :++: #d) :++: #binds :++: #tl)) $ - weakenExpr (wCopies (d2ace envPro) (WCopy @(D2 eltty) (wCopies (bindingsBinds e0) (wUndoSubenv subAccumUsed)))) $ - weakenExpr (wPro (bindingsBinds e0)) $ - e2)) $ - ELet ext (ENil ext) $ - ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ))) - }} - -} - -- TODO: merge the e0 and e1 builds in a single build just like they are merged into a single case in D[case]0, then it can really store only the parts that need to be preserved until D[build]2 EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty) | Ret (she0 :: Bindings _ _ she_binds) _ she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 8b4acb3..8e84378 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -152,7 +152,6 @@ dfwdDN = \case (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) (EConstArr ext n t x)) (EConstArr ext n t x) - EBuild1 _ a b -> EBuild1 ext (dfwdDN a) (dfwdDN b) EBuild _ n a b | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) EFold1Inner _ a b c -> EFold1Inner ext (dfwdDN a) (dfwdDN b) (dfwdDN c) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 36543e9..abc9800 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -90,10 +90,6 @@ interpret'Rec env = \case EJust _ e -> Just <$> interpret' env e EMaybe _ a b e -> maybe (interpret' env a) (\x -> interpret' (Value x `SCons` env) b) =<< interpret' env e EConstArr _ _ _ v -> return v - EBuild1 _ a b -> do - n <- fromIntegral @Int64 @Int <$> interpret' env a - arrayGenerateLinM (ShNil `ShCons` n) - (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b) EBuild _ dim a b -> do sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx ixUncons dim idx) `SCons` env) b) diff --git a/src/Language.hs b/src/Language.hs index a1b3d8b..e8dc89f 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -3,6 +3,7 @@ {-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} module Language ( fromNamed, NExpr, @@ -65,7 +66,18 @@ constArr_ x = Dict -> NEConstArr knownNat ty x build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) -build1 a (v :-> b) = NEBuild1 a v b +build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b)) + +build2 :: NExpr env TIx -> NExpr env TIx + -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t) + -> NExpr env (TArr (S (S Z)) t) +build2 a1 a2 (v1 :-> v2 :-> b) = + NEBuild (SS (SS SZ)) + (pair (pair nil a1) a2) + #idx + (let_ v1 (snd_ (fst_ #idx)) $ + let_ v2 (NEDrop SZ (snd_ #idx)) $ + NEDrop (SS (SS SZ)) b) build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) build n a (v :-> b) = NEBuild n a v b @@ -131,9 +143,6 @@ infix 4 .>= not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) not_ = oper ONot --- | The "_" variables in scope are unusable and should be ignored. With a --- weakening function on NExprs they could be hidden. --- --- The first alternative is the True case; the second is the False case. -if_ :: NExpr env (TScal TBool) -> NExpr ('("_", TNil) : env) t -> NExpr ('("_", TNil) : env) t -> NExpr env t -if_ e a b = case_ (oper OIf e) (#_ :-> a) (#_ :-> b) +-- | The first alternative is the True case; the second is the False case. +if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t +if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 3b04bec..f5203e9 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -30,6 +30,9 @@ data NExpr env t where NEVar :: Lookup name env ~ t => Var name t -> NExpr env t NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t + -- environment management + NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t + -- base types NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) NEFst :: NExpr env (TPair a b) -> NExpr env a @@ -41,7 +44,6 @@ data NExpr env t where -- array operations NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) - NEBuild1 :: NExpr env TIx -> Var name TIx -> NExpr ('(name, TIx) : env) t -> NExpr env (TArr (S Z) t) NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) @@ -68,6 +70,10 @@ type family Lookup name env where Lookup name ('(name, t) : env) = t Lookup name (_ : env) = Lookup name env +type family DropNth i env where + DropNth Z (_ : env) = env + DropNth (S i) (p : env) = p : DropNth i env + data Var name t = Var (SSymbol name) (STy t) deriving (Show) @@ -135,6 +141,8 @@ fromNamedExpr val = \case \expression to De Bruijn expression" NELet n a b -> ELet ext (go a) (lambda val n b) + NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e) + NEPair a b -> EPair ext (go a) (go b) NEFst e -> EFst ext (go e) NESnd e -> ESnd ext (go e) @@ -144,7 +152,6 @@ fromNamedExpr val = \case NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b) NEConstArr n t x -> EConstArr ext n t x - NEBuild1 a n b -> EBuild1 ext (go a) (lambda val n b) NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) NEFold1Inner n1 n2 a b c -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) (go c) NESum1Inner e -> ESum1Inner ext (go e) @@ -185,3 +192,13 @@ fromNamedExpr val = \case injectWrapLet e (arg `SCons` args) = injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e) args + +dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env) +dropNth SZ (val `NPush` _) = val +dropNth (SS i) (val `NPush` p) = dropNth i val `NPush` p +dropNth _ NTop = error "DropNth: index out of range" + +dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env +dropNthW SZ (_ `NPush` _) = WSink +dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) +dropNthW _ NTop = error "DropNth: index out of range" diff --git a/src/Simplify.hs b/src/Simplify.hs index f8b4b63..66a4004 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -110,7 +110,6 @@ simplify' = \case EJust _ e -> EJust ext <$> simplify' e EMaybe _ a b e -> EMaybe ext <$> simplify' a <*> simplify' b <*> simplify' e EConstArr _ n t v -> pure $ EConstArr ext n t v - EBuild1 _ a b -> EBuild1 ext <$> simplify' a <*> simplify' b EBuild _ n a b -> EBuild ext n <$> simplify' a <*> simplify' b EFold1Inner _ a b c -> EFold1Inner ext <$> simplify' a <*> simplify' b <*> simplify' c ESum1Inner _ e -> ESum1Inner ext <$> simplify' e @@ -156,7 +155,6 @@ hasAdds = \case EJust _ e -> hasAdds e EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e EConstArr _ _ _ _ -> False - EBuild1 _ a b -> hasAdds a || hasAdds b EBuild _ _ a b -> hasAdds a || hasAdds b EFold1Inner _ a b c -> hasAdds a || hasAdds b || hasAdds c ESum1Inner _ e -> hasAdds e |