diff options
Diffstat (limited to 'src/Compile.hs')
| -rw-r--r-- | src/Compile.hs | 104 | 
1 files changed, 81 insertions, 23 deletions
| diff --git a/src/Compile.hs b/src/Compile.hs index 8c77cd6..4c07f3a 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -13,7 +13,7 @@ import Data.Foldable (toList)  import Data.Functor.Const  import qualified Data.Functor.Product as Product  import Data.Functor.Product (Product) -import Data.List (intersperse, intercalate) +import Data.List (foldl1', intersperse, intercalate)  import qualified Data.Map.Strict as Map  import qualified Data.Set as Set  import Data.Set (Set) @@ -90,6 +90,7 @@ data CExpr    = CELit String  -- ^ inserted as-is, assumed no parentheses needed    | CEStruct String [(String, CExpr)]  -- ^ struct construction literal: `(name){.field=expr}`    | CEProj CExpr String  -- ^ field projection: expr.field +  | CEPtrProj CExpr String  -- ^ field projection through pointer: expr->field    | CEAddrOf CExpr  -- ^ &expr    | CEIndex CExpr CExpr  -- ^ expr[expr]    | CECall String [CExpr]  -- ^ function(arg1, ..., argn) @@ -143,6 +144,7 @@ printCExpr d = \case                                                 | (n, e) <- pairs])        . showString "}"    CEProj e name -> printCExpr 99 e . showString ("." ++ name) +  CEPtrProj e name -> printCExpr 99 e . showString ("->" ++ name)    CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e    CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]"    CECall n es -> @@ -594,26 +596,12 @@ compile' env = \case      return (CEStruct strname [("buf", CEAddrOf (CELit tldname))])    EBuild _ n esh efun -> do -    let arrty = STArr n (typeOf efun) -    strname <- emitStruct arrty -      shname <- genName' "sh"      emit . SVarDecl True (repSTy (typeOf esh)) shname =<< compile' env esh      shsizename <- genName' "shsz"      emit $ SVarDecl True "size_t" shsizename (compileShapeSize n shname) -    bufpname <- genName' "bufp" -    emit $ SVarDecl True (repSTy arrty ++ "_buf *") bufpname -             (CECall "malloc" [CEBinop (CELit (show (fromSNat n * 8 + 8))) -                                       "+" -                                       (CEBinop (CELit shsizename) -                                                "*" (CELit (show (sizeofSTy (typeOf efun)))))]) -    forM_ (zip (compileShapeTupIntoArray n shname) [0::Int ..]) $ \(rhs, i) -> -      emit $ SAsg (bufpname ++ "->sh[" ++ show i ++ "]") rhs -    emit $ SAsg (bufpname ++ "->refc") (CELit "1") - -    arrname <- genName' "arr" -    emit $ SVarDecl True (repSTy arrty) arrname (CEStruct strname [("buf", CELit bufpname)]) +    arrname <- allocArray "arr" n (typeOf efun) (CELit shsizename) (compileShapeTupIntoArray n shname)      idxargname <- genName' "ix"      (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun @@ -623,20 +611,57 @@ compile' env = \case      emit $ SBlock $        pure (SVarDecl False "size_t" linivar (CELit "0"))        <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") -                               (CECast (repSTy tIx) (CEIndex (CELit (bufpname ++ "->sh")) (CELit (show dimidx)))) +                               (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".buf->sh")) (CELit (show dimidx))))                   | (ivar, dimidx) <- zip ivars [0::Int ..]]             (pure (SVarDecl True (repSTy (typeOf esh)) idxargname                             (shapeTupFromLitVars n ivars))              <> BList funstmts -            <> pure (SAsg (bufpname ++ "->a[" ++ linivar ++ "++]") funretval)) +            <> pure (SAsg (arrname ++ ".buf->a[" ++ linivar ++ "++]") funretval))      return (CELit arrname)    -- EFold1Inner _ a b c -> error "TODO" -- EFold1Inner ext (compile' a) (compile' b) (compile' c) -  -- ESum1Inner _ e -> error "TODO" -- ESum1Inner ext (compile' e) +  ESum1Inner _ e -> do +    let STArr (SS n) t = typeOf e +    e' <- compile' env e +    argname <- genName' "sumarg" +    emit $ SVarDecl True (repSTy (STArr (SS n) t)) argname e' + +    shszname <- genName' "shsz" +    -- This n is one less than the shape of the thing we're querying, which is +    -- unexpected. But it's exactly what we want, so we do it anyway. +    emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) -  -- EUnit _ e -> error "TODO" -- EUnit ext (compile' e) +    resname <- allocArray "sumres" n t (CELit shszname) +                  [CELit (argname ++ ".buf->sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + +    lenname <- genName' "n" +    emit $ SVarDecl True (repSTy tIx) lenname +                    (CELit (argname ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")) + +    ivar <- genName' "i" +    jvar <- genName' "j" +    accvar <- genName' "tot" +    emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList +             -- we have ScalIsNumeric, so it has 0 and (+) in C +             [SVarDecl False (repSTy t) accvar (CELit "0") +             ,SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ +                pure $ SVerbatim $ accvar ++ " += " ++ argname ++ ".buf->a[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "];" +             ,SAsg (resname ++ ".buf->a[" ++ ivar ++ "]") (CELit accvar)] + +    return (CELit resname) + +  EUnit _ e -> do +    e' <- compile' env e +    let typ = STArr SZ (typeOf e) +    strname <- emitStruct typ +    name <- genName +    emit $ SVarDecl True strname name (CEStruct strname +                  [("buf", CECall "malloc" [CELit (show (8 + sizeofSTy (typeOf e)))])]) +    emit $ SAsg (name ++ ".buf->refc") (CELit "1") +    emit $ SAsg (name ++ ".buf->a[0]") e' +    return (CELit name)    -- EReplicate1Inner _ a b -> error "TODO" -- EReplicate1Inner ext (compile' a) (compile' b) @@ -646,7 +671,16 @@ compile' env = \case    EConst _ t x -> return $ CELit $ compileScal True t x -  -- EIdx0 _ e -> error "TODO" -- EIdx0 ext (compile' e) +  EIdx0 _ e -> do +    let STArr _ t = typeOf e +    e' <- compile' env e +    arrname <- genName +    emit $ SVarDecl True (repSTy (STArr SZ t)) arrname e' +    name <- genName +    emit $ SVarDecl True (repSTy t) name +              (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "a") (CELit "0")) +    incrementVarAlways Decrement (STArr SZ t) arrname +    return (CELit name)    -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b) @@ -655,7 +689,7 @@ compile' env = \case      arrname <- genName' "ixarr"      idxname <- genName' "ixix"      emit . SVarDecl True (repSTy (typeOf earr)) arrname =<< compile' env earr -    emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx +    when (fromSNat n > 0) $ emit . SVarDecl True (repSTy (typeOf eidx)) idxname =<< compile' env eidx      resname <- genName' "ixres"      emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->a")) (toLinearIdx n arrname idxname))      incrementVarAlways Decrement (STArr n t) arrname @@ -771,6 +805,23 @@ toLinearIdx (SS n) arrvar idxvar =  --   emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".buf->sh[" ++ show (fromSNat n) ++ "]")))  --   _ +-- | The shape must have the outer dimension at the head (and the inner dimension on the right). +allocArray :: String -> SNat n -> STy t -> CExpr -> [CExpr] -> CompM String +allocArray nameBase rank eltty shsz shape = do +  when (length shape /= fromSNat rank) $ +    error "allocArray: shape does not match rank" +  let arrty = STArr rank eltty +  strname <- emitStruct arrty +  arrname <- genName' nameBase +  emit $ SVarDecl True strname arrname $ CEStruct strname +            [("buf", CECall "malloc" [CEBinop (CELit (show (fromSNat rank * 8 + 8))) +                                              "+" +                                              (CEBinop shsz "*" (CELit (show (sizeofSTy eltty))))])] +  forM_ (zip shape [0::Int ..]) $ \(dim, i) -> +    emit $ SAsg (arrname ++ ".buf->sh[" ++ show i ++ "]") dim +  emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") +  return arrname +  compileShapeQuery :: SNat n -> String -> CExpr  compileShapeQuery SZ _ = CEStruct (repSTy STNil) []  compileShapeQuery (SS n) var = @@ -779,10 +830,17 @@ compileShapeQuery (SS n) var =      ,("b", CEIndex (CELit (var ++ ".buf->sh")) (CELit (show (fromSNat n))))]  compileShapeSize :: SNat n -> String -> CExpr -compileShapeSize SZ _ = CELit "0" +compileShapeSize SZ _ = CELit "1"  compileShapeSize (SS SZ) var = CELit (var ++ ".b")  compileShapeSize (SS n) var = CEBinop (compileShapeSize n (var ++ ".a")) "*" (CELit (var ++ ".b")) +-- | Takes a variable name for the array, not the buffer. +compileArrShapeSize :: SNat n -> String -> CExpr +compileArrShapeSize SZ _ = CELit "1" +compileArrShapeSize n var = +  foldl1' (\a b -> CEBinop a "*" b) [CELit (var ++ ".buf->sh[" ++ show i ++ "]") +                                    | i <- [0 .. fromSNat n - 1]] +  compileShapeTupIntoArray :: SNat n -> String -> [CExpr]  compileShapeTupIntoArray = \n var -> map CELit (toList (go n var))    where | 
