diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-26 23:05:30 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-26 23:05:30 +0100 |
commit | ade38c607a8d0dc8dc1d701084ed88df2fa89df9 (patch) | |
tree | 2183d63164a27fe84bb00b6c1920fe6c2be1e0e8 | |
parent | ae2b1b71a91d60d3bd1dfb21fce98c05c1a4fcbb (diff) |
Working argument accum mode (...)
The derivative of 'neural' in full accum mode is pretty atrocious now; I
think this is because when you have code like this:
\(a :: Arr 1 R) ->
let b = a
in let c = b
in sum c
then because the argument, as well as both let bindings, bind a value of
array type, each will introduce an accumulator, hence resulting in three
(!) nested `with` clauses that each just contribute their result back to
their parent. This is pointless, and we should fix this.
-rw-r--r-- | src/AST/Pretty.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 25 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 2 | ||||
-rw-r--r-- | test/Main.hs | 8 |
4 files changed, 25 insertions, 12 deletions
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index ec8574f..663e9b0 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -218,7 +218,7 @@ ppExpr' d val = \case return $ showParen (d > 10) $ showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' - EZero _ -> return $ showString "zero" + EZero t -> return $ showString ("(zero :: " ++ ppTy 0 t ++ ")") EPlus _ a b -> do a' <- ppExpr' 11 val a diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 56ebf82..bb4952c 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -30,7 +30,6 @@ import System.IO (hPutStrLn, stderr) import System.IO.Unsafe (unsafePerformIO) import Debug.Trace -import GHC.Stack import Array import AST @@ -250,7 +249,6 @@ onehotD2 _ STAccum{} _ _ = error "onehotD2: cannot index into accumulator" withAccum :: STy t -> STy a -> Rep t -> (RepAcSparse t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) withAccum t _ initval f = AcM $ do - putStrLn $ "withAccum: " ++ show t accum <- newAcSparse t SZ () initval out <- case f accum of AcM m -> m val <- readAcSparse t accum @@ -324,7 +322,7 @@ piindexConcat :: PartialInvIndex n m -> InvIndex m -> InvIndex n piindexConcat PIIxEnd ix = ix piindexConcat (PIIxCons i pix) ix = InvCons i (piindexConcat pix ix) -newAcSparse :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) +newAcSparse :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcSparse t) newAcSparse typ SZ () val = case typ of STNil -> return () STPair t1 t2 -> newIORef =<< (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) @@ -372,13 +370,19 @@ onehotArray (_ :: STy t) mkone mkzero dim dep@SS{} idx val = do go mk dep' dim' idx' val' $ \arr pish -> k arr (PIIxCons (fromIntegral @Int64 @Int i) pish) -newAcDense :: HasCallStack => STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> IO (RepAcDense t) +newAcDense :: STy t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t 
i) -> IO (RepAcDense t) newAcDense typ SZ () val = case typ of + STPair t1 t2 -> (,) <$> newAcSparse t1 SZ () (fst val) <*> newAcSparse t2 SZ () (snd val) STEither t1 t2 -> case val of Left x -> Left <$> newAcSparse t1 SZ () x Right y -> Right <$> newAcSparse t2 SZ () y _ -> error "newAcDense: invalid dense type" newAcDense typ (SS dep) idx val = case typ of + STPair t1 t2 -> + case (idx, val) of + (Left idx', Left val') -> (,) <$> newAcSparse t1 dep idx' val' <*> newAcZero t2 + (Right idx', Right val') -> (,) <$> newAcZero t1 <*> newAcSparse t2 dep idx' val' + _ -> error "Index/value mismatch in newAc pair" STEither t1 t2 -> case (idx, val) of (Left idx', Left val') -> Left <$> newAcSparse t1 dep idx' val' @@ -400,6 +404,7 @@ readAcSparse typ val = case typ of readAcDense :: STy t -> RepAcDense t -> IO (Rep t) readAcDense typ val = case typ of + STPair t1 t2 -> (,) <$> readAcSparse t1 (fst val) <*> readAcSparse t2 (snd val) STEither t1 t2 -> case val of Left x -> Left <$> readAcSparse t1 x Right y -> Right <$> readAcSparse t2 y @@ -505,19 +510,27 @@ accumAddSparse typ (SS dep) ref idx val = case typ of accumAddDense :: forall t i s. 
STy t -> SNat i -> RepAcDense t -> Rep (AcIdx t i) -> Rep (AcVal t i) -> (RepAcDense t, AcM s ()) accumAddDense typ SZ ref () val = case typ of + STPair t1 t2 -> + (ref, do accumAddSparse t1 SZ (fst ref) () (fst val) + accumAddSparse t2 SZ (snd ref) () (snd val)) STEither t1 t2 -> case (ref, val) of (Left ref', Left val') -> (ref, accumAddSparse t1 SZ ref' () val') (Right ref', Right val') -> (ref, accumAddSparse t2 SZ ref' () val') - _ -> error "Mismatched Either in accumAdd" + _ -> error "Mismatched Either in accumAddDense either" _ -> error "accumAddDense: invalid dense type" accumAddDense typ (SS dep) ref idx val = case typ of + STPair t1 t2 -> + case (idx, val) of + (Left idx', Left val') -> (ref, accumAddSparse t1 dep (fst ref) idx' val') + (Right idx', Right val') -> (ref, accumAddSparse t2 dep (snd ref) idx' val') + _ -> error "Mismatched Either in accumAddDense pair" STEither t1 t2 -> case (ref, idx, val) of (Left ref', Left idx', Left val') -> (Left ref', accumAddSparse t1 dep ref' idx' val') (Right ref', Right idx', Right val') -> (Right ref', accumAddSparse t2 dep ref' idx' val') - _ -> error "Mismatched Either in accumAdd" + _ -> error "Mismatched Either in accumAddDense either" _ -> error "accumAddDense: invalid dense type" diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 0007991..335ad1f 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -37,7 +37,7 @@ type family RepAcSparse t where -- Immutable, and does not necessarily have a zero. type family RepAcDense t where RepAcDense TNil = () - -- RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b) + RepAcDense (TPair a b) = (RepAcSparse a, RepAcSparse b) RepAcDense (TEither a b) = Either (RepAcSparse a) (RepAcSparse b) -- RepAcDense (TMaybe t) = RepAcSparse (TMaybe t) -- ^ This can be optimised to TMaybe (RepAcSparse t), but that makes accumAddDense very hard to write. And in any case, we don't need it because D2 will not produce Maybe of Maybe. 
-- RepAcDense (TArr n t) = Array n (RepAcSparse t) diff --git a/test/Main.hs b/test/Main.hs index d18884e..b6f9f2b 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -62,11 +62,11 @@ gradientByCHAD' = \simplIters env term input -> second (second (toTanE env input toTan typ primal der = case typ of STNil -> der STPair t1 t2 -> case der of - Left () -> bimap (zeroTan t1) (zeroTan t2) primal - Right (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal + Nothing -> bimap (zeroTan t1) (zeroTan t2) primal + Just (d₁, d₂) -> bimap (\p1 -> toTan t1 p1 d₁) (\p2 -> toTan t2 p2 d₂) primal STEither t1 t2 -> case der of - Left () -> bimap (zeroTan t1) (zeroTan t2) primal - Right d -> case (primal, d) of + Nothing -> bimap (zeroTan t1) (zeroTan t2) primal + Just d -> case (primal, d) of (Left p, Left d') -> Left (toTan t1 p d') (Right p, Right d') -> Right (toTan t2 p d') _ -> error "Primal and cotangent disagree on Either alternative" |