summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2025-03-17 23:19:44 +0100
committerTom Smeding <tom@tomsmeding.com>2025-03-17 23:19:44 +0100
commitdee165294d6b92b153a0b65e21f58f8073186d68 (patch)
tree0c9af35080dea594cd72d63c77a978a6a3616906
parent8ca9ceef96afffdc9d4bc266c978a6b4374131e6 (diff)
Compile EAccum
-rw-r--r--src/AST/Accum.hs1
-rw-r--r--src/Compile.hs70
2 files changed, 67 insertions, 4 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 6c46ad5..67c5de7 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -29,6 +29,7 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b
SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b
SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b
+ -- TODO: This SNat is rather useless, you always have an STy around too
SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b
-- TODO:
-- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
diff --git a/src/Compile.hs b/src/Compile.hs
index 5c9d1a2..d9cfd95 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeApplications #-}
module Compile (compile) where
@@ -10,6 +11,7 @@ import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Writer.CPS
import Data.Bifunctor (first)
+import Data.Char (ord)
import Data.Foldable (toList)
import Data.Functor.Const
import qualified Data.Functor.Product as Product
@@ -22,6 +24,7 @@ import Data.Set (Set)
import Data.Some
import qualified Data.Vector as V
import Foreign
+import Numeric (showHex)
import System.IO (hPutStrLn, stderr)
import Prelude hiding ((^))
@@ -768,11 +771,71 @@ compile' env = \case
return $ CEStruct (repSTy (STPair (typeOf e2) t)) [("a", e2'), ("b", CELit accname)]
- -- EAccum _ n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e)
+ EAccum _ t prj eidx eval eacc -> do
+ eidx' <- compile' env eidx
+ nameidx <- genName
+ emit $ SVarDecl True (repSTy (typeOf eidx)) nameidx eidx'
+
+ eval' <- compile' env eval
+ nameval <- genName
+ emit $ SVarDecl True (repSTy (typeOf eval)) nameval eval'
+
+ eacc' <- compile' env eacc
+ nameacc <- genName
+ emit $ SVarDecl True (repSTy (typeOf eacc)) nameacc eacc'
+
+ let accumRef :: STy a -> SAcPrj p a b -> String -> String -> String
+ accumRef _ SAPHere v _ = v
+ accumRef (STPair ta _) (SAPFst prj') v i = accumRef ta prj' (v++".a") i
+ accumRef (STPair _ tb) (SAPSnd prj') v i = accumRef tb prj' (v++".b") i
+ accumRef (STEither ta _) (SAPLeft prj') v i = accumRef ta prj' (v++".l") i
+ accumRef (STEither _ tb) (SAPRight prj') v i = accumRef tb prj' (v++".r") i
+ accumRef (STMaybe tj) (SAPJust prj') v i = accumRef tj prj' (v++".j") i
+ accumRef (STArr n t') (SAPArrIdx prj' _) v i =
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b")
+
+ let add :: STy a -> String -> String -> CompM ()
+ add STNil _ _ = return ()
+ add (STPair t1 t2) d s = do
+ add t1 (d++".a") (s++".a")
+ add t2 (d++".b") (s++".b")
+ add (STEither t1 t2) d s = do
+ ((), stmts1) <- scope $ add t1 (d++".l") (s++".l")
+ ((), stmts2) <- scope $ add t2 (d++".r") (s++".r")
+ emit $ SAsg (d++".tag") (CELit (s++".tag"))
+ emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "0"))
+ (BList stmts1) (BList stmts2)
+ add (STMaybe t1) d s = do
+ ((), stmts1) <- scope $ add t1 (d++".j") (s++".j")
+ emit $ SAsg (d++".tag") (CELit (s++".tag"))
+ emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1"))
+ (BList stmts1) mempty
+ add (STArr n t1) d s = do
+ shsizename <- genName' "acshsz"
+ emit $ SVarDecl True "size_t" shsizename (compileShapeSize n (s++".a.b"))
+ ivar <- genName' "i"
+ ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")
+ emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) $
+ BList stmts1
+ add (STScal sty) d s = case sty of
+ STI32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+ STI64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+ STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+ STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";"
+ STBool -> error "Compile: accumulator add on booleans"
+ add (STAccum _) _ _ = error "Compile: nested accumulators unsupported"
+
+ let dest = accumRef t prj (nameacc++".ac") nameidx
+ add (typeOf eval) dest nameval
+
+ return $ CEStruct (repSTy STNil) []
EError _ t s -> do
- -- using 'show' here is wrong, but it's good enough for me.
- emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ show s ++ "); exit(1);"
+ let padleft len c s' = replicate (len - length s) c ++ s'
+ escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
+ | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "")
+ | otherwise -> [c]
+ emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ escape s ++ "); exit(1);"
case t of
STScal _ -> return (CELit "0")
_ -> do
@@ -785,7 +848,6 @@ compile' env = \case
EFold1Inner{} -> error "Compile: not implemented: EFold1Inner"
EIdx1{} -> error "Compile: not implemented: EIdx1"
- EAccum{} -> error "Compile: not implemented: EAccum"
data Increment = Increment | Decrement
deriving (Show)