diff options
author | Tom Smeding <tom@tomsmeding.com> | 2025-03-17 23:19:44 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2025-03-17 23:19:44 +0100 |
commit | dee165294d6b92b153a0b65e21f58f8073186d68 (patch) | |
tree | 0c9af35080dea594cd72d63c77a978a6a3616906 | |
parent | 8ca9ceef96afffdc9d4bc266c978a6b4374131e6 (diff) |
Compile EAccum
-rw-r--r-- | src/AST/Accum.hs | 1 | ||||
-rw-r--r-- | src/Compile.hs | 70 |
2 files changed, 67 insertions, 4 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 6c46ad5..67c5de7 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -29,6 +29,7 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b + -- TODO: This SNat is rather useless, you always have an STy around too SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b -- TODO: -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) diff --git a/src/Compile.hs b/src/Compile.hs index 5c9d1a2..d9cfd95 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TypeApplications #-} module Compile (compile) where @@ -10,6 +11,7 @@ import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State.Strict import Control.Monad.Trans.Writer.CPS import Data.Bifunctor (first) +import Data.Char (ord) import Data.Foldable (toList) import Data.Functor.Const import qualified Data.Functor.Product as Product @@ -22,6 +24,7 @@ import Data.Set (Set) import Data.Some import qualified Data.Vector as V import Foreign +import Numeric (showHex) import System.IO (hPutStrLn, stderr) import Prelude hiding ((^)) @@ -768,11 +771,71 @@ compile' env = \case return $ CEStruct (repSTy (STPair (typeOf e2) t)) [("a", e2'), ("b", CELit accname)] - -- EAccum _ n a b e -> error "TODO" -- EAccum n (compile' a) (compile' b) (compile' e) + EAccum _ t prj eidx eval eacc -> do + eidx' <- compile' env eidx + nameidx <- genName + emit $ SVarDecl True (repSTy (typeOf eidx)) nameidx eidx' + + eval' <- compile' env eval + nameval <- genName + emit $ SVarDecl True (repSTy (typeOf eval)) nameval eval' + + eacc' <- compile' env eacc + nameacc <- genName + emit $ SVarDecl True (repSTy (typeOf eacc)) nameacc eacc' + + let accumRef :: STy a -> SAcPrj p a b -> String -> String -> String + accumRef _ SAPHere v _ = v + accumRef (STPair ta _) (SAPFst prj') v i = accumRef ta prj' (v++".a") i + accumRef (STPair _ tb) (SAPSnd prj') v i = accumRef tb prj' (v++".b") i + accumRef (STEither ta _) (SAPLeft prj') v i = accumRef ta prj' (v++".l") i + accumRef (STEither _ tb) (SAPRight prj') v i = accumRef tb prj' (v++".r") i + accumRef (STMaybe tj) (SAPJust prj') v i = accumRef tj prj' (v++".j") i + accumRef (STArr n t') (SAPArrIdx prj' _) v i = + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") + + let add :: STy a -> String -> String -> CompM () + add STNil _ _ = return () + add (STPair t1 t2) d s = do + add t1 (d++".a") (s++".a") + add t2 (d++".b") (s++".b") + add (STEither t1 t2) d s = do + ((), stmts1) <- scope $ add t1 (d++".l") (s++".l") + ((), stmts2) <- scope $ add t2 (d++".r") (s++".r") + emit $ SAsg (d++".tag") (CELit (s++".tag")) + emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "0")) + (BList stmts1) (BList stmts2) + add (STMaybe t1) d s = do + ((), stmts1) <- scope $ add t1 (d++".j") (s++".j") + emit $ SAsg (d++".tag") (CELit (s++".tag")) + emit $ SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) + (BList stmts1) mempty + add (STArr n t1) d s = do + shsizename <- genName' "acshsz" + emit $ SVarDecl True "size_t" shsizename (compileShapeSize n (s++".a.b")) + ivar <- genName' "i" + ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) $ + BList stmts1 + add (STScal sty) d s = case sty of + STI32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" + STI64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" + STF32 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" + STF64 -> emit $ SVerbatim $ d ++ " += " ++ s ++ ";" + STBool -> error "Compile: accumulator add on booleans" + add (STAccum _) _ _ = error "Compile: nested accumulators unsupported" + + let dest = accumRef t prj (nameacc++".ac") nameidx + add (typeOf eval) dest nameval + + return $ CEStruct (repSTy STNil) [] EError _ t s -> do - -- using 'show' here is wrong, but it's good enough for me. - emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ show s ++ "); exit(1);" + let padleft len c s' = replicate (len - length s) c ++ s' + escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] + | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "") + | otherwise -> [c] + emit $ SVerbatim $ "fprintf(stderr, \"ERROR: %s\\n\", " ++ escape s ++ "); exit(1);" case t of STScal _ -> return (CELit "0") _ -> do @@ -785,7 +848,6 @@ compile' env = \case EFold1Inner{} -> error "Compile: not implemented: EFold1Inner" EIdx1{} -> error "Compile: not implemented: EIdx1" - EAccum{} -> error "Compile: not implemented: EAccum" data Increment = Increment | Decrement deriving (Show) |