aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-11-04 23:09:21 +0100
committerTom Smeding <tom@tomsmeding.com>2025-11-04 23:09:21 +0100
commit57779d4303f377004705c8da06a5ac46177950b2 (patch)
tree0407089403d3d5c2de778c1aab7aed8adf2d01c0
parent351667a3ff14c96a8dfe3a2f1dd76b6e1a996542 (diff)
drevLambda works, TODO D[map]HEADmaster
-rw-r--r--src/AST.hs8
-rw-r--r--src/AST/Count.hs26
-rw-r--r--src/AST/Pretty.hs14
-rw-r--r--src/AST/SplitLets.hs6
-rw-r--r--src/Analysis/Identity.hs10
-rw-r--r--src/CHAD.hs112
-rw-r--r--src/Compile.hs13
-rw-r--r--src/Interpreter.hs8
-rw-r--r--src/Language.hs32
-rw-r--r--src/Language/AST.hs41
10 files changed, 148 insertions, 122 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 873a8a5..ca6cdd1 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -65,7 +65,7 @@ data Expr x env t where
EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)
-- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)
- EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
+ EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
@@ -79,7 +79,7 @@ data Expr x env t where
-- values in some implementation-defined order.
-- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs.
EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative
- -> Expr x (t1 : t1 : env) (TPair t1 b)
+ -> Expr x (TPair t1 t1 : env) (TPair t1 b)
-> Expr x env t1
-> Expr x env (TArr (S n) t1)
-> Expr x env (TPair (TArr n t1) -- normal primal fold output
@@ -403,7 +403,7 @@ subst' f w = \case
EConstArr x n t a -> EConstArr x n t a
EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b)
- EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
+ EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c)
ESum1Inner x e -> ESum1Inner x (subst' f w e)
EUnit x e -> EUnit x (subst' f w e)
EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
@@ -411,7 +411,7 @@ subst' f w = \case
EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
EZip x a b -> EZip x (subst' f w a) (subst' f w b)
- EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
+ EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c)
EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index bc02417..a53822d 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -560,22 +560,23 @@ occCountX initialS topexpr k = case topexpr of
EMap ext (mka (OccPush env' () s1)) (mkb env')
EFold1Inner _ commut a b c ->
- occCountX SsFull a $ \env1''' mka ->
- withSome (scaleMany (Some env1''')) $ \env1'' ->
- occEnvPop' env1'' $ \env1' s2 ->
- occEnvPop' env1' $ \env1 s1 ->
- let s0 = case s of
+ occCountX SsFull a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1' ->
+ let s1 = case s1' of
+ SsNone -> Some SsNone
+ SsPair' s1'a s1'b -> Some s1'a <> Some s1'b
+ s0 = case s of
SsNone -> Some SsNone
SsArr' s' -> Some s' in
- withSome (Some s1 <> Some s2 <> s0) $ \sElt ->
+ withSome (s1 <> s0) $ \sElt ->
occCountX sElt b $ \env2 mkb ->
- occCountX (SsArr sElt) c $ \env3 mkc ->
- withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
+ occCountX (SsArr sElt) c $ \env3 mkc ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
projectSmallerSubstruc (SsArr sElt) s $
EFold1Inner ext commut
(projectSmallerSubstruc SsFull sElt $
- mka (OccPush (OccPush env' () sElt) () sElt))
+ mka (OccPush env' () (SsPair sElt sElt)))
(mkb env') (mkc env')
ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
@@ -665,7 +666,7 @@ occCountX initialS topexpr k = case topexpr of
elet (mapExt (\_ -> ext) e3) $
EPair ext
(EShape ext (evar IZ))
- (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1)))
+ (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1)))
(mapExt (\_ -> ext) (weakenExpr WSink e2))
(evar IZ))
in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex ->
@@ -675,15 +676,14 @@ occCountX initialS topexpr k = case topexpr of
-- If at least some of the additional stores are required, we need to keep this a mapAccum
SsPair' _ (SsArr' sB) ->
-- TODO: propagate usage of primals
- occCountX (SsPair SsFull sB) e1 $ \env1_2' mka ->
- occEnvPop' env1_2' $ \env1_1' _ ->
+ occCountX (SsPair SsFull sB) e1 $ \env1_1' mka ->
occEnvPop' env1_1' $ \env1' _ ->
occCountX SsFull e2 $ \env2 mkb ->
occCountX SsFull e3 $ \env3 mkc ->
withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
k env $ \env' ->
projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $
- EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
+ EFold1InnerD1 ext cm (mka (OccPush env' () SsFull))
(mkb env') (mkc env')
EFold1InnerD2 _ cm ef ebog ed ->
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 2c51b85..ecdaa88 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -213,14 +213,13 @@ ppExpr' d val expr = case expr of
ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b']
EFold1Inner _ cm a b c -> do
- name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
- name2 <- genNameIfUsedIn (typeOf a) IZ a
- a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
+ name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
let opname = "fold1i" ++ ppCommut cm
return $ ppParen (d > 10) $
- ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c']
ESum1Inner _ e -> do
e' <- ppExpr' 11 val e
@@ -254,14 +253,13 @@ ppExpr' d val expr = case expr of
return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2']
EFold1InnerD1 _ cm a b c -> do
- name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a
- name2 <- genNameIfUsedIn (typeOf b) IZ a
- a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
+ name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
let opname = "fold1iD1" ++ ppCommut cm
return $ ppParen (d > 10) $
- ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c']
EFold1InnerD2 _ cm ef ebog ed -> do
let STArr _ tB = typeOf ebog
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index d276e44..267dd87 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -34,10 +34,10 @@ splitLets' = \sub -> \case
in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c)
EFold1Inner x cm a b c ->
let STArr _ t1 = typeOf c
- in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c)
EFold1InnerD1 x cm a b c ->
let STArr _ t1 = typeOf c
- in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c)
EFold1InnerD2 x cm a b c ->
let STArr _ tB = typeOf b
STArr _ t2 = typeOf c
@@ -56,12 +56,14 @@ splitLets' = \sub -> \case
ELInr x t e -> ELInr x t (splitLets' sub e)
EConstArr x n t a -> EConstArr x n t a
EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b)
+ EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b)
ESum1Inner x e -> ESum1Inner x (splitLets' sub e)
EUnit x e -> EUnit x (splitLets' sub e)
EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b)
+ EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (splitLets' sub e)
EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 71da793..7b896a3 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -213,9 +213,8 @@ idana env expr = case expr of
EFold1Inner _ cm e1 e2 e3 -> do
let t1 = typeOf e1
- x1 <- genIds t1
- x2 <- genIds t1
- (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1
+ x1 <- genIds (STPair t1 t1)
+ (_, e1') <- idana (x1 `SCons` env) e1
(_, e2') <- idana env e2
(v3, e3') <- idana env e3
let VIArr _ (_ :< sh) = v3
@@ -268,9 +267,8 @@ idana env expr = case expr of
EFold1InnerD1 _ cm e1 e2 e3 -> do
let t1 = typeOf e2
- x1 <- genIds t1
- x2 <- genIds t1
- (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1
+ x1 <- genIds (STPair t1 t1)
+ (_, e1') <- idana (x1 `SCons` env) e1
(_, e2') <- idana env e2
(v3, e3') <- idana env e3
let VIArr _ sh'@(_ :< sh) = v3
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 72ce36d..9da5395 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1077,37 +1077,29 @@ drev des accumMap sd = \case
ESnd ext $
wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $
EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $
- -- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> sinkOverEnvPro @> IS IZ))
- (EVar ext shty IZ)) $
-- the cotangent for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> WSink .> sinkOverEnvPro @> IZ))
+ ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty IZ)) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
(EVar ext shty (IS IZ))) $
- weakenExpr (autoWeak library (#d :++: #tape :++: #pro :++: #d2acEnv)
- (#d :++: #tape :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
+ weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv)
+ (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
e2)
- EMap{} -> undefined
+ EMap{} -> error "TODO: CHAD EMap"
EFold1Inner _ commut origef ex₀ earr
| SpArr @_ @sdElt sdElt <- sd
, STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr
, Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil)
<- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil ->
- deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') ->
- let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in
- subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
- accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
- let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
- let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
- let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
- let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in
- case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 ->
- let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
- let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
+ drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) subEf wrapAccum ef2 ->
+ let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in
+ let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape)
+ bogTy = STArr (SS ndim) bogEltTy
primalTy = STPair (STArr ndim (d1 eltty)) bogTy
- zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf)))
- library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil)
+ library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil)
&. #parr (auto1 @(TArr (S n) (D1 elt)))
&. #px₀ (auto1 @(D1 elt))
&. #px (auto1 @(D1 elt))
@@ -1118,70 +1110,52 @@ drev des accumMap sd = \case
&. #x₀abinds (bindingsBinds bindsx₀a)
&. #fbinds (bindingsBinds ef0)
&. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a)
- &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
- &. #ftape (auto1 @(Tape e_tape))
- &. #primalzip (zipPrimalTy `SCons` SNil)
- &. #efPrerebinds efPrerebinds
- &. #propr (d1e envPro)
+ &. #ftape (auto1 @ef_tape)
+ &. #bogelt (bogEltTy `SCons` SNil)
+ &. #propr (d1e provars)
&. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes)
- &. #d2acUsed (d2ace (select SAccum usedDes))
&. #d2acEnv (d2ace (select SAccum des))
- &. #d2acPro (d2ace envPro)
+ &. #d2acPro (d2ace provars)
&. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro))))
wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in
subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a ->
- subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f ->
- Ret (bconcat bindsx₀a mergePrimalBindings'
+ subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f ->
+ Ret (bconcat bindsx₀a proPrimalBinds'
`bpush` weakenExpr wOverPrimalBindings ex₀1
`bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ)
`bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1
`bpush` EFold1InnerD1 ext commut
(let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in
- letBinds (fst (weakenBindingsE (autoWeak library
- (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- layout)
- ef0)) $
- elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#fbinds :++: layout))
- ef1) $
- EPair ext
- (evar IZ)
+ letBinds (fst (weakenBindingsE (autoWeak library (#xy :++: #d1env) layout) ef0)) $
+ EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape)
+ (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1)
(EPair ext
- (evar IZ)
- (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout)))))
+ (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: layout) @> IZ))
+ (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1tape)))
(EVar ext (d1 eltty) (IS (IS IZ)))
(EVar ext (STArr (SS ndim) (d1 eltty)) IZ))
- (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro)))))))
+ (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars)))))))
(EFst ext (EVar ext primalTy IZ))
subx₀af
(let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in
elet
- (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $
- makeAccumulators (autoWeak library #propr layout1) envPro $
- let layout2 = #d2acPro :++: layout1 in
- EFold1InnerD2 ext commut
- (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $
- elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $
- elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $
- letBinds (efRebinds (IS (IS IZ))) $
- let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in
- elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $
- weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3
- .> wPro (subList (bindingsBinds ef0) subtapeEf))
- ef2) $
- EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ)))
- (ezip
- (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ))
- (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
- (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
- (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
- (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $
+ (wrapAccum (autoWeak library #propr layout1) $
+ let layout2 = #d2acPro :++: layout1 in
+ EFold1InnerD2 ext commut
+ (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $
+ let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in
+ expandSparse (STPair eltty eltty) subEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $
+ weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2)
+ (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))
+ (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ)))
+ (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ))
+ (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $
plus_x₀a_f
(plus_x₀_a
(elet (EIdx0 ext
(EFold1Inner ext Commut
- (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ))
+ (let t = STPair (d2 eltty) (d2 eltty)
+ in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ)))
(EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ)))
(eflatten (EFst ext (EFst ext (evar IZ)))))) $
weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1))
@@ -1189,7 +1163,6 @@ drev des accumMap sd = \case
(weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $
subst0 (ESnd ext (EFst ext (evar IZ))) ea2))
(ESnd ext (evar IZ)))
- }
EUnit _ e
| SpArr sdElt <- sd
@@ -1213,9 +1186,8 @@ drev des accumMap sd = \case
(EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1)
sub
(ELet ext (EFold1Inner ext Commut
- (sparsePlus (d2M eltty) sdElt'
- (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ))
- (EVar ext (applySparse sdElt' (d2 eltty)) IZ))
+ (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty))
+ in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ)))
(inj2 (ENil ext))
(emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $
weakenExpr (WCopy WSink) e2)
@@ -1494,7 +1466,7 @@ drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False)
D1E provars :> env'
-> Ex (Append (D2AcE provars) env') b
-> Ex ( env') (TPair b (Tup (D2E provars))))
- -> Ex (dt : tape : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a'
+ -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a'
-> r)
-> r
drevLambda des accumMap (argty, argsto) sd origef k =
@@ -1535,10 +1507,10 @@ drevLambda des accumMap (argty, argsto) sd origef k =
uninvertTup (d2e envPro) (typeOf body) $
makeAccumulators wpro1 envPro $
body)
- (letBinds (efRebinds (IS IZ)) $
+ (letBinds (efRebinds IZ) $
weakenExpr
(autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#ftapebinds :++: #efPrerebinds) :++: #d :++: #ftape :++: #d2acPro :++: #d2acEnv)
+ ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv)
.> wPro (subList (bindingsBinds ef0) subtapeEf))
(getSparseArg ef2))
}}
diff --git a/src/Compile.hs b/src/Compile.hs
index d6ad7ec..8627905 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -840,11 +840,14 @@ compile' env = \case
-- kvar <- if vecwid > 1 then genName' "k" else return ""
accvar <- genName' "tot"
+ pairvar <- genName' "pair" -- function input
+ (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun
+
let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++
({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]"
- (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun
((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit
+ pairstrname <- emitStruct (STPair t t)
emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $
pure (SVarDecl False (repSTy t) accvar (CELit x0name))
<> x0incrStmts -- we're copying x0 here
@@ -854,6 +857,7 @@ compile' env = \case
-- what comes out of the function anyway, so that's
-- fine, but we do need to increment the array element.
arreltIncrStmts
+ <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)]))
<> funStmts
<> pure (SAsg accvar funres))
<> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
@@ -997,12 +1001,14 @@ compile' env = \case
jvar <- genName' "j"
accvar <- genName' "tot"
+ pairvar <- genName' "pair" -- function input
+ (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun
let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar
arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]"
- (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun
funresvar <- genName' "res"
((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit
+ pairstrname <- emitStruct (STPair t t)
emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $
pure (SVarDecl False (repSTy t) accvar (CELit x0name))
<> x0incrStmts -- we're copying x0 here
@@ -1012,8 +1018,9 @@ compile' env = \case
-- what comes out of the function anyway, so that's
-- fine, but we do need to increment the array element.
arreltIncrStmts
+ <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)]))
<> funStmts
- <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
+ <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres)
<> pure (SAsg accvar (CEProj (CELit funresvar) "a"))
<> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b")))
<> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar))
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index d982261..e1c81cd 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -121,11 +121,11 @@ interpret'Rec env = \case
arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b
EFold1Inner _ _ a b c -> do
let t = typeOf b
- let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a
+ let f = \x -> interpret' (V (STPair t t) x `SCons` env) a
x0 <- interpret' env b
arr <- interpret' env c
let sh `ShCons` n = arrayShape arr
- arrayGenerateM sh $ \idx -> foldM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
ESum1Inner _ e -> do
arr <- interpret' env e
let STArr _ (STScal t) = typeOf e
@@ -162,14 +162,14 @@ interpret'Rec env = \case
return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i))
EFold1InnerD1 _ _ a b c -> do
let t = typeOf b
- let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a
+ let f = \x -> interpret' (V (STPair t t) x `SCons` env) a
x0 <- interpret' env b
arr <- interpret' env c
let sh `ShCons` n = arrayShape arr
-- TODO: this is very inefficient, even for an interpreter; with mutable
-- arrays this can be a lot better with no lists
res <- arrayGenerateM sh $ \idx -> do
- (y, stores) <- mapAccumLM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
return (y, arrayFromList (ShNil `ShCons` n) stores)
return (arrayMap fst res
,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) ->
diff --git a/src/Language.hs b/src/Language.hs
index 31b4b87..c1a6248 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedLabels #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
@@ -15,6 +16,8 @@ module Language (
Lookup,
) where
+import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol)
+
import Array
import AST
import AST.Sparse.Types
@@ -113,7 +116,19 @@ map_ (v :-> a) b
NEDrop (SS SZ) (NEDrop (SS SZ) a)
fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
-fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3
+fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
+ withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) ->
+ assertSymbolNotUnderscore s3 $
+ equalityReflexive s3 $
+ assertSymbolDistinct s3 s1 $
+ let v3 = Var s3 (STPair t t)
+ in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $
+ let_ v2 (snd_ (NEVar v3)) $
+ NEDrop (SS (SS SZ)) e1)
+ e2 e3
+
+fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3
sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
sum1i e = NESum1Inner e
@@ -135,7 +150,20 @@ reshape = NEReshape
fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b))
-> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
-fold1iD1 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD1 v1 v2 e1 e2 e3
+fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 =
+ withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) ->
+ assertSymbolNotUnderscore s3 $
+ equalityReflexive s3 $
+ assertSymbolDistinct s3 s1 $
+ let v3 = Var s3 (STPair t1 t1)
+ in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $
+ let_ v2 (snd_ (NEVar v3)) $
+ NEDrop (SS (SS SZ)) e1)
+ e2 e3
+
+fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b))
+ -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b))
+fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3
fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2))
-> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2))
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index c9d05c9..a3b8130 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -4,7 +4,9 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
@@ -17,7 +19,7 @@ module Language.AST where
import Data.Kind (Type)
import Data.Type.Equality
import GHC.OverloadedLabels
-import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..))
+import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal)
import Array
import AST
@@ -50,7 +52,7 @@ data NExpr env t where
-- array operations
NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
- NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+ NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEUnit :: NExpr env t -> NExpr env (TArr Z t)
NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
@@ -58,7 +60,7 @@ data NExpr env t where
NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t)
- NEFold1InnerD1 :: Var n1 t1 -> Var n2 t1 -> NExpr ('(n2, t1) : '(n1, t1) : env) (TPair t1 b)
+ NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b)
-> NExpr env t1
-> NExpr env (TArr (S n) t1)
-> NExpr env (TPair (TArr n t1) (TArr (S n) b))
@@ -96,11 +98,16 @@ data NExpr env t where
NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t
deriving instance Show (NExpr env t)
-type family Lookup name env where
- Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'")
- Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
- Lookup name ('(name, t) : env) = t
- Lookup name (_ : env) = Lookup name env
+type Lookup name env = Lookup1 (name == "_") name env
+type family Lookup1 eqblank name env where
+ Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'")
+ Lookup1 False name env = Lookup2 name env
+type family Lookup2 name env where
+ Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope")
+ Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env
+type family Lookup3 eq t name env where
+ Lookup3 True t _ _ = t
+ Lookup3 False _ name env = Lookup2 name env
type family DropNth i env where
DropNth Z (_ : env) = env
@@ -209,7 +216,7 @@ fromNamedExpr val = \case
NEConstArr n t x -> EConstArr ext n t x
NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
- NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+ NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c)
NESum1Inner e -> ESum1Inner ext (go e)
NEUnit e -> EUnit ext (go e)
NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
@@ -217,7 +224,7 @@ fromNamedExpr val = \case
NEMinimum1Inner e -> EMinimum1Inner ext (go e)
NEReshape n a b -> EReshape ext n (go a) (go b)
- NEFold1InnerD1 n1 n2 a b c -> EFold1InnerD1 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
+ NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c)
NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c)
NEConst t x -> EConst ext t x
@@ -275,3 +282,17 @@ dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env
dropNthW SZ (_ `NPush` _) = WSink
dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val)
dropNthW _ NTop = error "DropNth: index out of range"
+
+assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r
+assertSymbolNotUnderscore s@SSymbol k =
+ case symbolVal s of
+ "_" -> error "assertSymbolNotUnderscore: was underscore"
+ _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k
+
+assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r
+assertSymbolDistinct s1@SSymbol s2@SSymbol k
+ | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")"
+ | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k
+
+equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r
+equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k