diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-12-06 19:54:53 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-12-06 19:54:53 +0100 |
commit | 3e266262ebe65bd5d775711b4d05bc9670a38a47 (patch) | |
tree | bf0fff187e53adb8a4f45b3d7c70c97566c1e141 | |
parent | 40a0abca1cedcdd930bb33d1874b7922443e5a8c (diff) |
UnMonoid
-rw-r--r-- | chad-fast.cabal | 2 | ||||
-rw-r--r-- | src/AST.hs | 27 | ||||
-rw-r--r-- | src/AST/UnMonoid.hs | 118 | ||||
-rw-r--r-- | test/Main.hs | 35 |
4 files changed, 155 insertions, 27 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index c3c2682..0887f17 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -75,7 +75,7 @@ test-suite test transformers, hs-source-dirs: test default-language: Haskell2010 - ghc-options: -Wall -threaded + ghc-options: -Wall -threaded -rtsopts benchmark bench type: exitcode-stdio-1.0 @@ -396,17 +396,24 @@ emap f arr = (EVar ext (tTup (sreplicate n tIx)) IZ)) $ weakenExpr (WCopy (WSink .> WSink)) f -ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip a b = - let STArr n t1 = typeOf a - STArr _ t2 = typeOf b - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ +ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) +ezipWith f arr1 arr2 = + let STArr n t1 = typeOf arr1 + STArr _ t2 = typeOf arr2 + in ELet ext arr1 $ + ELet ext (weakenExpr WSink arr2) $ EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - EPair ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - (EIdx ext (EVar ext (STArr n t2) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) + ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ + weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f + +ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) +ezip arr1 arr2 = + let STArr _ t1 = typeOf arr1 + STArr _ t2 = typeOf arr2 + in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2 eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index 1675dab..8da1e32 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t unMonoid = \case EZero t -> zero t EPlus t a b -> plus t a b - EOneHot t i a b -> _ t i a b + EOneHot t i a b -> onehot t i a b EVar _ t i -> EVar ext t i ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) @@ -51,7 +51,8 @@ zero STNil = ENil ext zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2)) zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2)) zero (STMaybe t) = ENothing ext (d2 t) -zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t) +zero (STArr SZ t) = EUnit ext (zero t) +zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError (d2 t) "empty") zero (STScal t) = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -85,7 +86,13 @@ plus (STMaybe t) a b = plus (STArr n t) a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ - ECase + eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ)))) + (EVar ext (STArr n (d2 t)) IZ) + (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ))) + (EVar ext (STArr n (d2 t)) (IS IZ)) + (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ)) + (EVar ext (STArr n (d2 t)) (IS IZ)) + (EVar ext (STArr n (d2 t)) IZ))) plus (STScal t) a b = case t of STI32 -> ENil ext STI64 -> ENil ext @@ -108,3 +115,108 @@ plusSparse t a b adder = (weakenExpr (WCopy (WCopy WSink)) adder) (EVar ext (STMaybe t) (IS IZ)))) (weakenExpr WSink a) + +onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t) +onehot _ SZ _ val = val +onehot t (SS dep) idx val = case t of + STPair t1 t2 -> + case dep of + SZ -> EJust ext val + SS dep' -> + let STEither tidx1 tidx2 = typeOf idx + STEither tval1 tval2 = typeOf val + in EJust ext $ + ECase ext idx + (ECase ext (weakenExpr WSink val) + (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)) + (zero t2)) + (EError (STPair (d2 t1) (d2 t2)) "onehot pair l/r")) + (ECase ext (weakenExpr WSink val) + (EError (STPair (d2 t1) (d2 t2)) "onehot pair r/l") + (EPair ext (zero t1) + (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) + + STEither t1 t2 -> + case dep of + SZ -> EJust ext val + SS dep' -> + let STEither tidx1 tidx2 = typeOf idx + STEither tval1 tval2 = typeOf val + in EJust ext $ + ECase ext idx + (ECase ext (weakenExpr WSink val) + (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))) + (EError (STEither (d2 t1) (d2 t2)) "onehot either l/r")) + (ECase ext (weakenExpr WSink val) + (EError (STEither (d2 t1) (d2 t2)) "onehot either r/l") + (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ)))) + + STMaybe t1 -> EJust ext (onehot t1 dep idx val) + + STArr n t1 -> + ELet ext val $ + EBuild ext n (EFst ext (EVar ext (typeOf val) IZ)) + (onehotArrayElem t1 n (SS dep) + (EVar ext (tTup (sreplicate n tIx)) IZ) + (weakenExpr (WSink .> WSink) idx) + (ESnd ext (EVar ext (typeOf val) (IS IZ)))) + + STNil -> error "Cannot index into nil" + STScal{} -> error "Cannot index into scalar" + STAccum{} -> error "Accumulators not allowed in input program" + +onehotArrayElem + :: STy t -> SNat n -> SNat i + -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex' + -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot + -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole + -> Ex env (D2 t) +onehotArrayElem t n dep eltidx idx val = + ELet ext eltidx $ + ELet ext (weakenExpr WSink idx) $ + let (cond, elt) = onehotArrayElemRec t n dep + (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ))) + (EVar ext (typeOf idx) IZ) + (weakenExpr (WSink .> WSink) val) + in eif cond elt (zero t) + +-- AcIdx must be duplicable +onehotArrayElemRec + :: STy t -> SNat n -> SNat i + -> [Ex env TIx] + -> Ex env (AcIdx (TArr n (D2 t)) i) + -> Ex env (AcValArr n (D2 t) i) + -> (Ex env (TScal TBool), Ex env (D2 t)) +onehotArrayElemRec _ n SZ eltidx _ val = + (EConst ext STBool True + ,EIdx ext val (reconstructFromOutsideIn n eltidx)) +onehotArrayElemRec t SZ (SS dep) eltidx idx val = + case eltidx of + [] -> (EConst ext STBool True, onehot t dep idx val) + _ -> error "onehotArrayElemRec: mismatched list length" +onehotArrayElemRec t (SS n) (SS dep) eltidx idx val = + case eltidx of + i : eltidx' -> + let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val + in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond) + ,elt) + [] -> error "onehotArrayElemRec: mismatched list length" + +-- | Outermost index at the head. The input expression must be duplicable. +outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] +outsideInIndex = \n idx -> go n idx [] + where + go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx] + go SZ _ acc = acc + go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc) + +-- Takes a list with the outermost index at the head. Returns a tuple with the +-- innermost index on the right. +reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) +reconstructFromOutsideIn = \n list -> go n (reverse list) + where + -- Takes list with the _innermost_ index at the head. + go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx)) + go SZ [] = ENil ext + go (SS n) (i:is) = EPair ext (go n is) i + go _ _ = error "reconstructFromOutsideIn: mismatched list length" diff --git a/test/Main.hs b/test/Main.hs index b6f9f2b..5db7ea0 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -24,6 +24,7 @@ import Hedgehog.Main import Array import AST import AST.Pretty +import AST.UnMonoid import CHAD.Top import CHAD.Types import qualified Example @@ -274,19 +275,9 @@ tests = checkParallel $ Group "AD" let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $ 42) - ,("neural", adTestGen Example.neural $ do - let tR = STScal STF64 - let genLayer nin nout = - liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) - <*> genArray tR (ShNil `ShCons` nout) - nin <- Gen.integral (Range.linear 1 10) - n1 <- Gen.integral (Range.linear 1 10) - n2 <- Gen.integral (Range.linear 1 10) - input <- genArray tR (ShNil `ShCons` nin) - lay1 <- genLayer nin n1 - lay2 <- genLayer n1 n2 - lay3 <- genArray tR (ShNil `ShCons` n2) - return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)) + ,("neural", adTestGen Example.neural genNeural) + + ,("neural-unMonoid", adTestGen (unMonoid (simplifyFix Example.neural)) genNeural) ,("logsumexp", adTestTp (C "" 1) $ fromNamed $ lambda @(TArr N1 _) #vec $ body $ @@ -304,7 +295,11 @@ tests = checkParallel $ Group "AD" ,("gmm-wrong", withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM) + ,("gmm-wrong-unMonoid", withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM) + ,("gmm", withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM) + + ,("gmm-unMonoid", withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM) ] where genGMM = do @@ -330,5 +325,19 @@ tests = checkParallel $ Group "AD" Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons` SNil) + genNeural = do + let tR = STScal STF64 + let genLayer nin nout = + liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin) + <*> genArray tR (ShNil `ShCons` nout) + nin <- Gen.integral (Range.linear 1 10) + n1 <- Gen.integral (Range.linear 1 10) + n2 <- Gen.integral (Range.linear 1 10) + input <- genArray tR (ShNil `ShCons` nin) + lay1 <- genLayer nin n1 + lay2 <- genLayer n1 n2 + lay3 <- genArray tR (ShNil `ShCons` n2) + return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil) + main :: IO () main = defaultMain [tests] |