diff options
Diffstat (limited to 'src/Compile.hs')
| -rw-r--r-- | src/Compile.hs | 205 | 
1 files changed, 128 insertions, 77 deletions
| diff --git a/src/Compile.hs b/src/Compile.hs index 503c342..6ba3a39 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -1001,16 +1001,6 @@ compile' env = \case      return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]    EAccum _ t prj eidx eval eacc -> do -    nameidx <- compileAssign "acidx" env eidx -    nameval <- compileAssign "acval" env eval - -    -- Generate the variable manually because this one has to be non-const. -    -- TODO: old code: -    --   eacc' <- compile' env eacc -    --   nameacc <- genName' "acac" -    --   emit $ SVarDecl False (repSTy (typeOf eacc)) nameacc eacc' -    nameacc <- compileAssign "acac" env eacc -      let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a          -- full zero array with the given zero info (for the type SMTArr n t1).          initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM () @@ -1041,77 +1031,66 @@ compile' env = \case          initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi          initZeroFromMemset SMTScal{} = Nothing -    let -- initZero (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) -        initZero :: SMTy a -> String -> String -> CompM () -        initZero SMTNil _ _ = return () -        initZero (SMTPair t1 t2) v vzi = do -          initZero t1 (v++".a") (vzi++".a") -          initZero t2 (v++".b") (vzi++".b") -        initZero SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") -        initZero SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") -        initZero (SMTArr n t1) v vzi = initZeroArray n t1 v vzi -        initZero (SMTScal sty) v _ = case sty of +    let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type) +        initZeroZI :: SMTy a -> String -> String -> CompM () +        initZeroZI SMTNil _ _ = return () +        initZeroZI (SMTPair t1 t2) v vzi = do +          initZeroZI t1 (v++".a") (vzi++".a") +          initZeroZI t2 (v++".b") (vzi++".b") +        initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0") +        initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0") +        initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi +        initZeroZI (SMTScal sty) v _ = case sty of            STI32 -> emit $ SAsg v (CELit "0")            STI64 -> emit $ SAsg v (CELit "0l")            STF32 -> emit $ SAsg v (CELit "0.0f")            STF64 -> emit $ SAsg v (CELit "0.0") -    let -- | Dereference an accumulation value. Sparse components encountered -        -- along the way are initialised before proceeding downwards. At the -        -- point where we have the projected accumulator position available, -        -- the handler will be invoked with a variable name pointing to the -        -- projected position. -        -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (handler) -        accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> (String -> CompM ()) -> CompM () -        accumRef _ SAPHere v _ k = k v - -        accumRef (SMTPair ta _) (SAPFst prj') v i k = accumRef ta prj' (v++".a") (i++".a") k -        accumRef (SMTPair _ tb) (SAPSnd prj') v i k = accumRef tb prj' (v++".b") (i++".b") k - -        accumRef (SMTLEither ta _) (SAPLeft prj') v i k = do -          ((), stmtsInit1) <- scope $ initZero ta (v++".l") i -          emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) -                   (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty -          accumRef ta prj' (v++".l") i k -        accumRef (SMTLEither _ tb) (SAPRight prj') v i k = do -          ((), stmtsInit2) <- scope $ initZero tb (v++".r") i -          emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) -                   (pure (SAsg (v++".tag") (CELit "2")) <> stmtsInit2) mempty -          accumRef tb prj' (v++".r") i k - -        accumRef (SMTMaybe tj) (SAPJust prj') v i k = do -          ((), stmtsInit1) <- scope $ initZero tj (v++".j") i -          emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) -                   (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty -          accumRef tj prj' (v++".j") i k - -        accumRef (SMTArr n t') (SAPArrIdx prj') v i k = do -          when emitChecks $ do -            let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" -            forM_ (zip3 [0::Int ..] -                        (indexTupleComponents n (i++".a.a")) -                        (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do -              let a .||. b = CEBinop a "||" b -              emit $ SIf (CEBinop ixcomp "<" (CELit "0") -                          .||. -                          CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) -                          .||. -                          CEBinop shcomp "!=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))) -                       (pure $ SVerbatim $ -                          "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ -                          "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ -                          v ++ ".buf" ++ -                          concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ -                          concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ -                          concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ -                          "); " ++ -                          "return false;") -                       mempty - -          accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") k +    let -- Initialise an uninitialised accumulation value, potentially already +        -- with the addend, potentially to zero depending on the nature of the +        -- projection. +        -- 1. If the projection indexes only through dense monoids before +        --    reaching SAPHere, the thing cannot be initialised to zero with +        --    only an AcIdx; it would need to model a zero after the addend, +        --    which is stupid and redundant. In this case, we return Left: +        --    (accumulation value) (AcIdx value) (addend value). +        --    The addend is copied, not consumed. (We can't reliably _always_ +        --    consume it, so it's not worth trying to do it sometimes.) +        -- 2. Otherwise, a sparse monoid is found along the way, and we can +        --    initalise the dense prefix of the path to zero by setting the +        --    indexed-through sparse value to a sparse zero. Afterwards, the +        --    main recursion can proceed further. In this case, we return +        --    Right: (accumulation value) (AcIdx value) +        -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type) +        initZeroChunk :: SMTy a -> SAcPrj p a b +                      -> Either (String -> String -> String -> CompM ())  -- dense initialisation with addend +                                (String -> String -> CompM ())  -- zero initialisation of sparse chunk +        initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of +          -- reached target before the first sparse constructor +          (t1           , SAPHere    ) -> Left $ \v _ addend -> do +                                            incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend +                                            emit $ SAsg v (CELit addend) +          -- sparse types +          (SMTLEither{} , _          ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") +          (SMTMaybe{}   , _          ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0") +          -- dense types +          (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do +                                            f (v++".a") (i++".a") +                                            initZeroZI t2 (v++".b") (i++".b") +          (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do +                                            initZeroZI t1 (v++".a") (i++".a") +                                            f (v++".b") (i++".b") +          (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do +                                             initZeroArray n t1 v (i++".a.b") +                                             linidxvar <- genName' "li" +                                             emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a")) +                                             f (v++".buf->xs["++linidxvar++"]") (i++".b") +          where +            applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i +            applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i      let -- Add a value (s) into an existing accumulation value (d). If a sparse -        -- component of d is encountered, s is simply written there. +        -- component of d is encountered, s is copied there.          add :: SMTy a -> String -> String -> CompM ()          add SMTNil _ _ = return ()          add (SMTPair t1 t2) d s = do @@ -1165,16 +1144,88 @@ compile' env = \case                         mempty            shsizename <- genName' "acshsz" -          emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n (s++".j")) +          emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s)            ivar <- genName' "i"            ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]")            emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename)                     stmts1          add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" +    let -- | Dereference an accumulation value and add a given value to that +        -- position. Sparse components encountered along the way are +        -- initialised before proceeding downwards. +        -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there) +        accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () +        accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend + +        accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend +        accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend + +        accumRef (SMTLEither ta tb) prj0 v i addend = do +          let chunkres = case prj0 of SAPLeft  prj' -> initZeroChunk ta prj' +                                      SAPRight prj' -> initZeroChunk tb prj' +              subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r") +              tagval = case prj0 of SAPLeft{}  -> "1" +                                    SAPRight{} -> "2" +          ((), stmtsAdd) <- scope $ case prj0 of SAPLeft  prj' -> accumRef ta prj' subv i addend +                                                 SAPRight prj' -> accumRef tb prj' subv i addend +          case chunkres of +            Left densef -> do +              ((), stmtsSet) <- scope $ densef subv i addend +              emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) +                       (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet) +                       stmtsAdd  -- TODO: emit check for consistency of tags? +            Right sparsef -> do +              ((), stmtsInit) <- scope $ sparsef subv i +              emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) +                       (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty +              forM_ stmtsAdd emit + +        accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do +          case initZeroChunk tj prj' of +            Left densef -> do +              ((), stmtsSet1) <- scope $ densef (v++".j") i addend +              ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend +              emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) +                       (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1) +                       stmtsAdd1 +            Right sparsef -> do +              ((), stmtsInit1) <- scope $ sparsef (v++".j") i +              emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0")) +                       (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty +              accumRef tj prj' (v++".j") i addend + +        accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do +          when emitChecks $ do +            let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" +            forM_ (zip3 [0::Int ..] +                        (indexTupleComponents n (i++".a.a")) +                        (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do +              let a .||. b = CEBinop a "||" b +              emit $ SIf (CEBinop ixcomp "<" (CELit "0") +                          .||. +                          CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) +                          .||. +                          CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))) +                       (pure $ SVerbatim $ +                          "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ +                          "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++ +                          v ++ ".buf" ++ +                          concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ +                          concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++ +                          concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++ +                          "); " ++ +                          "return false;") +                       mempty + +          accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend + +    nameidx <- compileAssign "acidx" env eidx +    nameval <- compileAssign "acval" env eval +    nameacc <- compileAssign "acac" env eacc +      emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" -    accumRef t prj (nameacc++".buf->ac") nameidx $ \dest -> -      add (acPrjTy prj t) dest nameval +    accumRef t prj (nameacc++".buf->ac") nameidx nameval      emit $ SVerbatim $ "// compile EAccum end"      incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval | 
