summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-12-06 19:54:53 +0100
committerTom Smeding <tom@tomsmeding.com>2024-12-06 19:54:53 +0100
commit3e266262ebe65bd5d775711b4d05bc9670a38a47 (patch)
treebf0fff187e53adb8a4f45b3d7c70c97566c1e141
parent40a0abca1cedcdd930bb33d1874b7922443e5a8c (diff)
UnMonoid
-rw-r--r--chad-fast.cabal2
-rw-r--r--src/AST.hs27
-rw-r--r--src/AST/UnMonoid.hs118
-rw-r--r--test/Main.hs35
4 files changed, 155 insertions, 27 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index c3c2682..0887f17 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -75,7 +75,7 @@ test-suite test
transformers,
hs-source-dirs: test
default-language: Haskell2010
- ghc-options: -Wall -threaded
+ ghc-options: -Wall -threaded -rtsopts
benchmark bench
type: exitcode-stdio-1.0
diff --git a/src/AST.hs b/src/AST.hs
index 9ad0d4d..fff290a 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -396,17 +396,24 @@ emap f arr =
(EVar ext (tTup (sreplicate n tIx)) IZ)) $
weakenExpr (WCopy (WSink .> WSink)) f
-ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
-ezip a b =
- let STArr n t1 = typeOf a
- STArr _ t2 = typeOf b
- in ELet ext a $
- ELet ext (weakenExpr WSink b) $
+ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
+ezipWith f arr1 arr2 =
+ let STArr n t1 = typeOf arr1
+ STArr _ t2 = typeOf arr2
+ in ELet ext arr1 $
+ ELet ext (weakenExpr WSink arr2) $
EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
- EPair ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
- (EIdx ext (EVar ext (STArr n t2) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ))
+ ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ)))
+ (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $
+ weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f
+
+ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
+ezip arr1 arr2 =
+ let STArr _ t1 = typeOf arr1
+ STArr _ t2 = typeOf arr2
+ in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2
eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 1675dab..8da1e32 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -13,7 +13,7 @@ unMonoid :: Ex env t -> Ex env t
unMonoid = \case
EZero t -> zero t
EPlus t a b -> plus t a b
- EOneHot t i a b -> _ t i a b
+ EOneHot t i a b -> onehot t i a b
EVar _ t i -> EVar ext t i
ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body)
@@ -51,7 +51,8 @@ zero STNil = ENil ext
zero (STPair t1 t2) = ENothing ext (STPair (d2 t1) (d2 t2))
zero (STEither t1 t2) = ENothing ext (STEither (d2 t1) (d2 t2))
zero (STMaybe t) = ENothing ext (d2 t)
-zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (zero t)
+zero (STArr SZ t) = EUnit ext (zero t)
+zero (STArr n t) = EBuild ext n (eTup (sreplicate n (EConst ext STI64 0))) (EError (d2 t) "empty")
zero (STScal t) = case t of
STI32 -> ENil ext
STI64 -> ENil ext
@@ -85,7 +86,13 @@ plus (STMaybe t) a b =
plus (STArr n t) a b =
ELet ext a $
ELet ext (weakenExpr WSink b) $
- ECase
+ eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) (IS IZ))))
+ (EVar ext (STArr n (d2 t)) IZ)
+ (eif (eshapeEmpty n (EShape ext (EVar ext (STArr n (d2 t)) IZ)))
+ (EVar ext (STArr n (d2 t)) (IS IZ))
+ (ezipWith (plus t (EVar ext (d2 t) (IS IZ)) (EVar ext (d2 t) IZ))
+ (EVar ext (STArr n (d2 t)) (IS IZ))
+ (EVar ext (STArr n (d2 t)) IZ)))
plus (STScal t) a b = case t of
STI32 -> ENil ext
STI64 -> ENil ext
@@ -108,3 +115,108 @@ plusSparse t a b adder =
(weakenExpr (WCopy (WCopy WSink)) adder)
(EVar ext (STMaybe t) (IS IZ))))
(weakenExpr WSink a)
+
+onehot :: STy t -> SNat i -> Ex env (AcIdx (D2 t) i) -> Ex env (AcVal (D2 t) i) -> Ex env (D2 t)
+onehot _ SZ _ val = val
+onehot t (SS dep) idx val = case t of
+ STPair t1 t2 ->
+ case dep of
+ SZ -> EJust ext val
+ SS dep' ->
+ let STEither tidx1 tidx2 = typeOf idx
+ STEither tval1 tval2 = typeOf val
+ in EJust ext $
+ ECase ext idx
+ (ECase ext (weakenExpr WSink val)
+ (EPair ext (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ))
+ (zero t2))
+ (EError (STPair (d2 t1) (d2 t2)) "onehot pair l/r"))
+ (ECase ext (weakenExpr WSink val)
+ (EError (STPair (d2 t1) (d2 t2)) "onehot pair r/l")
+ (EPair ext (zero t1)
+ (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
+
+ STEither t1 t2 ->
+ case dep of
+ SZ -> EJust ext val
+ SS dep' ->
+ let STEither tidx1 tidx2 = typeOf idx
+ STEither tval1 tval2 = typeOf val
+ in EJust ext $
+ ECase ext idx
+ (ECase ext (weakenExpr WSink val)
+ (EInl ext (d2 t2) (onehot t1 dep' (EVar ext tidx1 (IS IZ)) (EVar ext tval1 IZ)))
+ (EError (STEither (d2 t1) (d2 t2)) "onehot either l/r"))
+ (ECase ext (weakenExpr WSink val)
+ (EError (STEither (d2 t1) (d2 t2)) "onehot either r/l")
+ (EInr ext (d2 t1) (onehot t2 dep' (EVar ext tidx2 (IS IZ)) (EVar ext tval2 IZ))))
+
+ STMaybe t1 -> EJust ext (onehot t1 dep idx val)
+
+ STArr n t1 ->
+ ELet ext val $
+ EBuild ext n (EFst ext (EVar ext (typeOf val) IZ))
+ (onehotArrayElem t1 n (SS dep)
+ (EVar ext (tTup (sreplicate n tIx)) IZ)
+ (weakenExpr (WSink .> WSink) idx)
+ (ESnd ext (EVar ext (typeOf val) (IS IZ))))
+
+ STNil -> error "Cannot index into nil"
+ STScal{} -> error "Cannot index into scalar"
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+onehotArrayElem
+ :: STy t -> SNat n -> SNat i
+ -> Ex env (Tup (Replicate n TIx)) -- ^ where are we now, OUTSIDE-IN as produced by 'outsideInIndex'
+ -> Ex env (AcIdx (TArr n (D2 t)) i) -- ^ where to put the one-hot
+ -> Ex env (AcValArr n (D2 t) i) -- ^ value to put in the hole
+ -> Ex env (D2 t)
+onehotArrayElem t n dep eltidx idx val =
+ ELet ext eltidx $
+ ELet ext (weakenExpr WSink idx) $
+ let (cond, elt) = onehotArrayElemRec t n dep
+ (outsideInIndex n (EVar ext (typeOf eltidx) (IS IZ)))
+ (EVar ext (typeOf idx) IZ)
+ (weakenExpr (WSink .> WSink) val)
+ in eif cond elt (zero t)
+
+-- AcIdx must be duplicable
+onehotArrayElemRec
+ :: STy t -> SNat n -> SNat i
+ -> [Ex env TIx]
+ -> Ex env (AcIdx (TArr n (D2 t)) i)
+ -> Ex env (AcValArr n (D2 t) i)
+ -> (Ex env (TScal TBool), Ex env (D2 t))
+onehotArrayElemRec _ n SZ eltidx _ val =
+ (EConst ext STBool True
+ ,EIdx ext val (reconstructFromOutsideIn n eltidx))
+onehotArrayElemRec t SZ (SS dep) eltidx idx val =
+ case eltidx of
+ [] -> (EConst ext STBool True, onehot t dep idx val)
+ _ -> error "onehotArrayElemRec: mismatched list length"
+onehotArrayElemRec t (SS n) (SS dep) eltidx idx val =
+ case eltidx of
+ i : eltidx' ->
+ let (cond, elt) = onehotArrayElemRec t n dep eltidx' (ESnd ext idx) val
+ in (EOp ext OAnd (EPair ext (EOp ext (OEq STI64) (EPair ext i (EFst ext idx))) cond)
+ ,elt)
+ [] -> error "onehotArrayElemRec: mismatched list length"
+
+-- | Outermost index at the head. The input expression must be duplicable.
+outsideInIndex :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx]
+outsideInIndex = \n idx -> go n idx []
+ where
+ go :: SNat n -> Ex env (Tup (Replicate n TIx)) -> [Ex env TIx] -> [Ex env TIx]
+ go SZ _ acc = acc
+ go (SS n) idx acc = go n (EFst ext idx) (ESnd ext idx : acc)
+
+-- Takes a list with the outermost index at the head. Returns a tuple with the
+-- innermost index on the right.
+reconstructFromOutsideIn :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
+reconstructFromOutsideIn = \n list -> go n (reverse list)
+ where
+ -- Takes list with the _innermost_ index at the head.
+ go :: SNat n -> [Ex env TIx] -> Ex env (Tup (Replicate n TIx))
+ go SZ [] = ENil ext
+ go (SS n) (i:is) = EPair ext (go n is) i
+ go _ _ = error "reconstructFromOutsideIn: mismatched list length"
diff --git a/test/Main.hs b/test/Main.hs
index b6f9f2b..5db7ea0 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -24,6 +24,7 @@ import Hedgehog.Main
import Array
import AST
import AST.Pretty
+import AST.UnMonoid
import CHAD.Top
import CHAD.Types
import qualified Example
@@ -274,19 +275,9 @@ tests = checkParallel $ Group "AD"
let_ #a (build1 (snd_ (shape #x)) (#i :-> #x ! pair nil #i)) $
42)
- ,("neural", adTestGen Example.neural $ do
- let tR = STScal STF64
- let genLayer nin nout =
- liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
- <*> genArray tR (ShNil `ShCons` nout)
- nin <- Gen.integral (Range.linear 1 10)
- n1 <- Gen.integral (Range.linear 1 10)
- n2 <- Gen.integral (Range.linear 1 10)
- input <- genArray tR (ShNil `ShCons` nin)
- lay1 <- genLayer nin n1
- lay2 <- genLayer n1 n2
- lay3 <- genArray tR (ShNil `ShCons` n2)
- return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil))
+ ,("neural", adTestGen Example.neural genNeural)
+
+ ,("neural-unMonoid", adTestGen (unMonoid (simplifyFix Example.neural)) genNeural)
,("logsumexp", adTestTp (C "" 1) $
fromNamed $ lambda @(TArr N1 _) #vec $ body $
@@ -304,7 +295,11 @@ tests = checkParallel $ Group "AD"
,("gmm-wrong", withShrinks 0 $ adTestGen (Example.gmmObjective True) genGMM)
+ ,("gmm-wrong-unMonoid", withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective True))) genGMM)
+
,("gmm", withShrinks 0 $ adTestGen (Example.gmmObjective False) genGMM)
+
+ ,("gmm-unMonoid", withShrinks 0 $ adTestGen (unMonoid (simplifyFix (Example.gmmObjective False))) genGMM)
]
where
genGMM = do
@@ -330,5 +325,19 @@ tests = checkParallel $ Group "AD"
Value (i2i64 kK) `SCons` Value (i2i64 kD) `SCons` Value (i2i64 kN) `SCons`
SNil)
+ genNeural = do
+ let tR = STScal STF64
+ let genLayer nin nout =
+ liftV2 (,) <$> genArray tR (ShNil `ShCons` nout `ShCons` nin)
+ <*> genArray tR (ShNil `ShCons` nout)
+ nin <- Gen.integral (Range.linear 1 10)
+ n1 <- Gen.integral (Range.linear 1 10)
+ n2 <- Gen.integral (Range.linear 1 10)
+ input <- genArray tR (ShNil `ShCons` nin)
+ lay1 <- genLayer nin n1
+ lay2 <- genLayer n1 n2
+ lay3 <- genArray tR (ShNil `ShCons` n2)
+ return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
+
main :: IO ()
main = defaultMain [tests]