diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/AST.hs | 54 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 32 | ||||
-rw-r--r-- | src/CHAD.hs | 22 | ||||
-rw-r--r-- | src/Data.hs | 4 | ||||
-rw-r--r-- | src/ForwardAD/DualNumbers.hs | 2 | ||||
-rw-r--r-- | src/Interpreter.hs | 34 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 22 | ||||
-rw-r--r-- | src/Language.hs | 2 |
8 files changed, 136 insertions, 36 deletions
@@ -147,9 +147,11 @@ data SOp a t where ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a) OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OEq :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) ONot :: SOp (TScal TBool) (TScal TBool) - OIf :: SOp (TScal TBool) (TEither TNil TNil) + OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) + OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) + OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right ORound64 :: SOp (TScal TF64) (TScal TI64) OToFl64 :: SOp (TScal TI64) (TScal TF64) deriving instance Show (SOp a t) @@ -163,6 +165,8 @@ opt2 = \case OLe _ -> STScal STBool OEq _ -> STScal STBool ONot -> STScal STBool + OAnd -> STScal STBool + OOr -> STScal STBool OIf -> STEither STNil STNil ORound64 -> STScal STI64 OToFl64 -> STScal STF64 @@ -206,23 +210,23 @@ typeOf = \case EError t _ -> t -unSNat :: SNat n -> Nat -unSNat SZ = Z -unSNat (SS n) = S (unSNat n) +-- unSNat :: SNat n -> Nat +-- unSNat SZ = Z +-- unSNat (SS n) = S (unSNat n) -unSTy :: STy t -> Ty -unSTy = \case - STNil -> TNil - STPair a b -> TPair (unSTy a) (unSTy b) - STEither a b -> TEither (unSTy a) (unSTy b) - STMaybe t -> TMaybe (unSTy t) - STArr n t -> TArr (unSNat n) (unSTy t) - STScal t -> TScal (unSScalTy t) - STAccum t -> TAccum (unSTy t) +-- unSTy :: STy t -> Ty +-- unSTy = \case +-- STNil -> TNil +-- STPair a b -> TPair (unSTy a) (unSTy b) +-- STEither a b -> TEither (unSTy a) (unSTy b) +-- STMaybe t -> TMaybe (unSTy t) +-- STArr n t -> TArr (unSNat n) (unSTy t) +-- STScal t -> TScal (unSScalTy t) +-- STAccum t -> TAccum (unSTy t) -unSList :: SList STy env -> [Ty] -unSList SNil = [] -unSList (SCons t l) = unSTy t : unSList l +-- unSEnv :: SList STy env -> [Ty] +-- unSEnv SNil = [] +-- unSEnv (SCons t l) = unSTy t : unSEnv l unSScalTy :: SScalTy t -> ScalTy unSScalTy = \case @@ -335,9 +339,25 @@ sscaltyKnown STF32 = Dict sscaltyKnown STF64 = Dict sscaltyKnown STBool = Dict +envKnown :: SList STy env -> Dict (KnownEnv env) +envKnown SNil = Dict +envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict + ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) $ let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) (EFst ext arg) + +eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) +eidxEq SZ _ _ = EConst ext STBool True +eidxEq (SS n) a b + | let ty = tTup (sreplicate (SS n) tIx) + = ELet ext a $ + ELet ext (weakenExpr WSink b) $ + EOp ext OAnd $ EPair ext + (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ))) + (ESnd ext (EVar ext ty IZ)))) + (eidxEq n (EFst ext (EVar ext ty (IS IZ))) + (EFst ext (EVar ext ty IZ))) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index b50506a..acd0dc3 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -6,7 +6,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeOperators #-} -module AST.Pretty (ppExpr) where +module AST.Pretty (ppExpr, ppTy) where import Control.Monad (ap) import Data.List (intersperse) @@ -42,10 +42,10 @@ genNameIfUsedIn' prefix ty idx ex genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = genNameIfUsedIn' "x" -ppExpr :: SList STy env -> Expr x env t -> String +ppExpr :: SList f env -> Expr x env t -> String ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) "" where - mkVal :: SList STy env -> M (SVal env) + mkVal :: SList f env -> M (SVal env) mkVal SNil = return SNil mkVal (SCons _ v) = do val <- mkVal v @@ -112,14 +112,14 @@ ppExpr' d val = \case EBuild1 _ a b -> do a' <- ppExpr' 11 val a - name <- genNameIfUsedIn (STScal STI64) IZ b + name <- genNameIfUsedIn' "i" (STScal STI64) IZ b b' <- ppExpr' 0 (Const name `SCons` val) b return $ showParen (d > 10) $ showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")" EBuild _ n a b -> do a' <- ppExpr' 11 val a - name <- genNameIfUsedIn (tTup (sreplicate n tIx)) IZ b + name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b e' <- ppExpr' 0 (Const name `SCons` val) b return $ showParen (d > 10) $ showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")" @@ -195,7 +195,7 @@ ppExpr' d val = \case e2' <- ppExpr' 11 val e2 e3' <- ppExpr' 11 val e3 return $ showParen (d > 10) $ - showString ("accum " ++ show (unSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' + showString ("accum " ++ show (fromSNat i) ++ " ") . e1' . showString " " . e2' . showString " " . e3' EZero _ -> return $ showString "zero" @@ -243,6 +243,26 @@ operator OLt{} = (Infix, "<") operator OLe{} = (Infix, "<=") operator OEq{} = (Infix, "==") operator ONot = (Prefix, "not") +operator OAnd = (Infix, "&&") +operator OOr = (Infix, "||") operator OIf = (Prefix, "ifB") operator ORound64 = (Prefix, "round") operator OToFl64 = (Prefix, "toFl64") + +ppTy :: Int -> STy t -> String +ppTy d ty = ppTys d ty "" + +ppTys :: Int -> STy t -> ShowS +ppTys _ STNil = showString "1" +ppTys d (STPair a b) = showParen (d > 7) $ ppTys 8 a . showString " * " . ppTys 8 b +ppTys d (STEither a b) = showParen (d > 6) $ ppTys 7 a . showString " + " . ppTys 7 b +ppTys d (STMaybe t) = showParen (d > 10) $ showString "Maybe " . ppTys 11 t +ppTys d (STArr n t) = showParen (d > 10) $ + showString "Arr " . shows (fromSNat n) . showString " " . ppTys 11 t +ppTys _ (STScal sty) = showString $ case sty of + STI32 -> "i32" + STI64 -> "i64" + STF32 -> "f32" + STF64 -> "f64" + STBool -> "bool" +ppTys d (STAccum t) = showParen (d > 10) $ showString "Accum " . ppTys 11 t diff --git a/src/CHAD.hs b/src/CHAD.hs index e77dbe7..dda434c 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -545,6 +545,8 @@ d1op (OLt t) e = EOp ext (OLt t) e d1op (OLe t) e = EOp ext (OLe t) e d1op (OEq t) e = EOp ext (OEq t) e d1op ONot e = EOp ext ONot e +d1op OAnd e = EOp ext OAnd e +d1op OOr e = EOp ext OOr e d1op OIf e = EOp ext OIf e d1op ORound64 e = EOp ext ORound64 e d1op OToFl64 e = EOp ext OToFl64 e @@ -564,6 +566,8 @@ d2op op = case op of OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) ONot -> Linear $ \_ -> ENil ext + OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) OIf -> Linear $ \_ -> ENil ext ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 OToFl64 -> Linear $ \_ -> ENil ext @@ -1078,15 +1082,19 @@ drev des = \case | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil , STArr n eltty <- typeOf e - , Refl <- indexTupD1Id n -> + , Refl <- indexTupD1Id n + , let tIxN = tTup (sreplicate n tIx) -> Ret (binds `BPush` (STArr n (d1 eltty), e1) - `BPush` (tTup (sreplicate n tIx), EShape ext (EVar ext (typeOf e1) IZ))) - (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS IZ)) - (weakenExpr (WSink .> WSink) ei1)) + `BPush` (tIxN, EShape ext (EVar ext (typeOf e1) IZ)) + `BPush` (tIxN, weakenExpr (WSink .> WSink) ei1)) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) sub - (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ)) - (EVar ext (d2 eltty) (IS IZ))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + (ELet ext (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) (IS IZ)) $ + ECase ext (EOp ext OIf (eidxEq n (EVar ext tIxN IZ) (EVar ext tIxN (IS (IS IZ))))) + (EVar ext (d2 eltty) (IS (IS IZ))) + (EZero eltty)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) e2) EShape _ e -- Allowed to ignore e2 here because the output of EShape is discrete, diff --git a/src/Data.hs b/src/Data.hs index e951ef2..c5d6219 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -30,6 +30,10 @@ slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list slistMap _ SNil = SNil slistMap f (SCons x list) = SCons (f x) (slistMap f list) +unSList :: (forall t. f t -> a) -> SList f list -> [a] +unSList _ SNil = [] +unSList f (x `SCons` l) = f x : unSList f l + sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2) sappend SNil l = l sappend (SCons x xs) l = SCons x (sappend xs l) diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index f02b93e..4f84e8d 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -99,6 +99,8 @@ dop = \case (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) (EOp ext (OEq t)) ONot -> EOp ext ONot + OAnd -> EOp ext OAnd + OOr -> EOp ext OOr OIf -> EOp ext OIf ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg) OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0) diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 3fb5d7b..3d6f33d 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImplicitParams #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} @@ -23,13 +24,14 @@ import Data.Char (isSpace) import Data.Kind (Type) import Data.Int (Int64) import Data.IORef -import GHC.Stack (HasCallStack) +import System.IO (hPutStrLn, stderr) import System.IO.Unsafe (unsafePerformIO) import Debug.Trace import Array import AST +import AST.Pretty import CHAD.Types import Data import Interpreter.Rep @@ -42,14 +44,25 @@ newtype AcM s a = AcM { unAcM :: IO a } runAcM :: (forall s. AcM s a) -> a runAcM (AcM m) = unsafePerformIO m +acmDebugLog :: String -> AcM s () +acmDebugLog s = AcM (hPutStrLn stderr s) + interpret :: Ex '[] t -> Rep t interpret = interpretOpen SNil interpretOpen :: SList Value env -> Ex env t -> Rep t -interpretOpen env e = runAcM (interpret' env e) - -interpret' :: forall env t s. HasCallStack => SList Value env -> Ex env t -> AcM s (Rep t) -interpret' env = \case +interpretOpen env e = runAcM (let ?depth = 0 in interpret' env e) + +interpret' :: forall env t s. (?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) +interpret' env e = do + let dep = ?depth + acmDebugLog $ replicate dep ' ' ++ "ev: " ++ ppExpr env e + res <- let ?depth = dep + 1 in interpret'Rec env e + acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" + return res + +interpret'Rec :: forall env t s. (?depth :: Int) => SList Value env -> Ex env t -> AcM s (Rep t) +interpret'Rec env = \case EVar _ _ i -> case slistIdx env i of Value x -> return x ELet _ a b -> do x <- interpret' env a @@ -125,11 +138,20 @@ interpretOp op arg = case op of ONeg st -> numericIsNum st $ negate arg OLt st -> numericIsNum st $ uncurry (<) arg OLe st -> numericIsNum st $ uncurry (<=) arg - OEq st -> numericIsNum st $ uncurry (==) arg + OEq st -> styIsEq st $ uncurry (==) arg ONot -> not arg + OAnd -> uncurry (&&) arg + OOr -> uncurry (||) arg OIf -> if arg then Left () else Right () ORound64 -> round arg OToFl64 -> fromIntegral arg + where + styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r + styIsEq STI32 = id + styIsEq STI64 = id + styIsEq STF32 = id + styIsEq STF64 = id + styIsEq STBool = id zeroD2 :: STy t -> Rep (D2 t) zeroD2 typ = case typ of diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs index 5c20183..baf38fc 100644 --- a/src/Interpreter/Rep.hs +++ b/src/Interpreter/Rep.hs @@ -3,6 +3,8 @@ {-# LANGUAGE UndecidableInstances #-} module Interpreter.Rep where +import Data.List (intersperse) +import Data.Foldable (toList) import Data.IORef import GHC.TypeError @@ -53,3 +55,23 @@ vPair = liftV2 (,) vUnpair :: Value (TPair a b) -> (Value a, Value b) vUnpair (Value (x, y)) = (Value x, Value y) + +showValue :: Int -> STy t -> Rep t -> ShowS +showValue _ STNil () = showString "()" +showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" +showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Left " . showValue 11 a x +showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Right " . showValue 11 b y +showValue _ (STMaybe _) Nothing = showString "Nothing" +showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x +showValue d (STArr _ t) arr = showParen (d > 10) $ + showString "arrayFromList " . showsPrec 11 (arrayShape arr) + . showString " [" + . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) + . showString "]" +showValue _ (STScal sty) x = case sty of + STF32 -> shows x + STF64 -> shows x + STI32 -> shows x + STI64 -> shows x + STBool -> shows x +showValue _ STAccum{} _ = error "Cannot show accumulators" diff --git a/src/Language.hs b/src/Language.hs index a025236..3a4a36c 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -126,5 +126,7 @@ not_ = oper ONot -- | The "_" variables in scope are unusable and should be ignored. With a -- weakening function on NExprs they could be hidden. +-- +-- The first alternative is the True case; the second is the False case. if_ :: NExpr env (TScal TBool) -> NExpr ('("_", TNil) : env) t -> NExpr ('("_", TNil) : env) t -> NExpr env t if_ e a b = case_ (oper OIf e) (#_ :-> a) (#_ :-> b) |