diff options
-rw-r--r-- | chad-fast.cabal | 3 | ||||
-rw-r--r-- | src/AST.hs | 49 | ||||
-rw-r--r-- | src/AST/Count.hs | 6 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 29 | ||||
-rw-r--r-- | src/Array.hs (renamed from src/Interpreter/Array.hs) | 26 | ||||
-rw-r--r-- | src/CHAD.hs | 11 | ||||
-rw-r--r-- | src/Example.hs | 20 | ||||
-rw-r--r-- | src/Interpreter.hs | 152 | ||||
-rw-r--r-- | src/Interpreter/Accum.hs | 76 | ||||
-rw-r--r-- | src/Interpreter/Rep.hs | 17 | ||||
-rw-r--r-- | src/Language.hs | 38 | ||||
-rw-r--r-- | src/Language/AST.hs | 13 | ||||
-rw-r--r-- | src/Simplify.hs | 12 |
13 files changed, 362 insertions, 90 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index ef1fd66..6314acd 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -10,6 +10,7 @@ build-type: Simple library exposed-modules: + Array AST AST.Count AST.Env @@ -22,7 +23,7 @@ library Example Interpreter Interpreter.Accum - Interpreter.Array + Interpreter.Rep Language Language.AST Lemmas @@ -20,6 +20,7 @@ import Data.Kind (Type) import Data.Int import Data.Type.Equality +import Array import AST.Env import AST.Weaken import Data @@ -91,6 +92,13 @@ type family ScalRep t where ScalRep TF64 = Double ScalRep TBool = Bool +type family ScalIsNumeric t where + ScalIsNumeric TI32 = True + ScalIsNumeric TI64 = True + ScalIsNumeric TF32 = True + ScalIsNumeric TF64 = True + ScalIsNumeric TBool = False + -- | This index is flipped around from the usual direction: the smallest index -- is at the _heart_ of the nesting, not at the outside. The outermost layer -- indexes into the _outer_ dimension of the type @t@. This makes indices into @@ -128,11 +136,13 @@ data Expr x env t where ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c -- array operations + EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) - -- EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused + EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) @@ -160,13 +170,16 @@ type family Tup env where Tup '[] = TNil Tup (t : ts) = TPair (Tup ts) t +mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b)) + -> SList f list -> f (Tup list) +mkTup nil _ SNil = nil +mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e + tTup :: SList STy env -> STy (Tup env) -tTup SNil = STNil -tTup (SCons t ts) = STPair (tTup ts) t +tTup = mkTup STNil STPair eTup :: SList (Ex env) list -> Ex env (Tup list) -eTup SNil = ENil ext -eTup (e `SCons` es) = EPair ext (eTup es) e +eTup = mkTup (ENil ext) (EPair ext) type family InvTup core env where InvTup core '[] = core @@ -174,12 +187,12 @@ type family InvTup core env where type SOp :: Ty -> Ty -> Type data SOp a t where - OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - OMul :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - ONeg :: SScalTy a -> SOp (TScal a) (TScal a) - OLt :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OLe :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OEq :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) ONot :: SOp (TScal TBool) (TScal TBool) OIf :: SOp (TScal TBool) (TEither TNil TNil) deriving instance Show (SOp a t) @@ -208,11 +221,13 @@ typeOf = \case EInr _ t1 e -> STEither t1 (typeOf e) ECase _ _ a _ -> typeOf a + EConstArr _ n t _ -> STArr n (STScal t) EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) EBuild _ n _ e -> STArr n (typeOf e) - EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t + EFold1Inner _ _ e | STArr (SS n) t <- typeOf e -> STArr n t + ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EUnit _ e -> STArr SZ (typeOf e) - -- EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t + EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t @@ -273,11 +288,13 @@ subst' f w = \case EInl x t e -> EInl x t (subst' f w e) EInr x t e -> EInr x t (subst' f w e) ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) + EConstArr x n t a -> EConstArr x n t a EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) + EFold1Inner x a b -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) + ESum1Inner x e -> ESum1Inner x (subst' f w e) EUnit x e -> EUnit x (subst' f w e) - -- EReplicate x e -> EReplicate x (subst' f w e) + EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 39d26c2..6a00e83 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -109,11 +109,13 @@ occCountGeneral onehot unpush alter many = go WId EInl _ _ e -> re e EInr _ _ e -> re e ECase _ e a b -> re e <> (re1 a `alter` re1 b) + EConstArr{} -> mempty EBuild1 _ a b -> re a <> many (re1 b) EBuild _ _ a b -> re a <> many (re1 b) - EFold1 _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b + EFold1Inner _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b + ESum1Inner _ e -> re e EUnit _ e -> re e - -- EReplicate _ e -> re e + EReplicate1Inner _ a b -> re a <> re b EConst{} -> mempty EIdx0 _ e -> re e EIdx1 _ a b -> re a <> re b diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index bf0d350..2ce883b 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -42,11 +42,6 @@ genNameIfUsedIn' prefix ty idx ex genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String genNameIfUsedIn = genNameIfUsedIn' "x" -valprj :: SList f env -> Idx env t -> f t -valprj (x `SCons` _) IZ = x -valprj (_ `SCons` env) (IS i) = valprj env i -valprj SNil i = case i of {} - ppExpr :: SList STy env -> Expr x env t -> String ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) "" where @@ -59,7 +54,7 @@ ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) "" ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS ppExpr' d val = \case - EVar _ _ i -> return $ showString $ getConst $ valprj val i + EVar _ _ i -> return $ showString $ getConst $ slistIdx val i e@ELet{} -> ppExprLet d val e @@ -97,6 +92,9 @@ ppExpr' d val = \case showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a' . showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }" + EConstArr _ _ ty v + | Dict <- scalRepIsShow ty -> return $ showsPrec d v + EBuild1 _ a b -> do a' <- ppExpr' 11 val a name <- genNameIfUsedIn (STScal STI64) IZ b @@ -111,25 +109,30 @@ ppExpr' d val = \case return $ showParen (d > 10) $ showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")" - EFold1 _ a b -> do + EFold1Inner _ a b -> do name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a name2 <- genNameIfUsedIn (typeOf a) IZ a a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a b' <- ppExpr' 11 val b return $ showParen (d > 10) $ - showString ("fold1 (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a' + showString ("fold1i (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a' . showString ") " . b' + ESum1Inner _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "sum1i " . e' + EUnit _ e -> do e' <- ppExpr' 11 val e return $ showParen (d > 10) $ showString "unit " . e' - -- EReplicate _ e -> do - -- e' <- ppExpr' 11 val e - -- return $ showParen (d > 10) $ showString "replicate " . e' + EReplicate1Inner _ a b -> do + a' <- ppExpr' 11 val a + b' <- ppExpr' 11 val b + return $ showParen (d > 10) $ showString "replicate1i " . a' . showString " " . b' - EConst _ ty v -> return $ showString $ case ty of - STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v + EConst _ ty v + | Dict <- scalRepIsShow ty -> return $ showsPrec d v EIdx0 _ e -> do e' <- ppExpr' 11 val e diff --git a/src/Interpreter/Array.hs b/src/Array.hs index 54e0791..9a770c4 100644 --- a/src/Interpreter/Array.hs +++ b/src/Array.hs @@ -1,8 +1,9 @@ {-# LANGUAGE KindSignatures #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TupleSections #-} -module Interpreter.Array where +module Array where import Control.Monad.Trans.State.Strict import Data.Foldable (traverse_) @@ -15,18 +16,32 @@ import Data data Shape n where ShNil :: Shape Z ShCons :: Shape n -> Int -> Shape (S n) +deriving instance Show (Shape n) data Index n where IxNil :: Index Z IxCons :: Index n -> Int -> Index (S n) +deriving instance Show (Index n) shapeSize :: Shape n -> Int shapeSize ShNil = 0 shapeSize (ShCons sh n) = shapeSize sh * n +fromLinearIndex :: Shape n -> Int -> Index n +fromLinearIndex ShNil 0 = IxNil +fromLinearIndex ShNil _ = error "Index out of range" +fromLinearIndex (sh `ShCons` n) i = + let (q, r) = i `quotRem` n + in fromLinearIndex sh q `IxCons` r + +toLinearIndex :: Shape n -> Index n -> Int +toLinearIndex ShNil IxNil = 0 +toLinearIndex (sh `ShCons` n) (idx `IxCons` i) = toLinearIndex sh idx * n + i + -- | TODO: this Vector is a boxed vector, which is horrendously inefficient. data Array (n :: Nat) t = Array (Shape n) (Vector t) + deriving (Show) arrayShape :: Array n t -> Shape n arrayShape (Array sh _) = sh @@ -34,9 +49,18 @@ arrayShape (Array sh _) = sh arraySize :: Array n t -> Int arraySize (Array sh _) = shapeSize sh +arrayIndex :: Array n t -> Index n -> t +arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx) + arrayIndexLinear :: Array n t -> Int -> t arrayIndexLinear (Array _ v) i = v V.! i +arrayIndex1 :: Array (S n) t -> Int -> Array n t +arrayIndex1 (Array (sh `ShCons` _) v) i = let sz = shapeSize sh in Array sh (V.slice (sz * i) sz v) + +arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t) +arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh) + arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t) arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f diff --git a/src/CHAD.hs b/src/CHAD.hs index 692bb96..943f0a2 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -968,6 +968,12 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) + EConstArr _ n t val -> + Ret BTop + (EConstArr ext n t val) + (subenvNone (select SMerge des)) + (ENil ext) + EBuild1 _ ne (orige :: Ex _ eltty) | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne -- allowed to ignore ne2 here because ne has a discrete result , let eltty = typeOf orige -> @@ -1075,9 +1081,10 @@ drev des = \case weakenExpr (WCopy (WSink .> WSink)) e2) -- These should be the next to be implemented, I think - EFold1{} -> err_unsupported "EFold1" + ESum1Inner{} -> err_unsupported "ESum" + EReplicate1Inner{} -> err_unsupported "EReplicate" EShape{} -> err_unsupported "EShape" - -- EReplicate{} -> err_unsupported "EReplicate" + EFold1Inner{} -> err_unsupported "EFold1Inner" EIdx{} -> err_unsupported "EIdx" EBuild{} -> err_unsupported "EBuild" diff --git a/src/Example.hs b/src/Example.hs index 4130f47..d1d04e3 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -140,3 +140,23 @@ ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $ in let_ #parstup #pars123 $ let_ #inp #input $ layer (STPair (STPair (STPair STNil tpair) tpair) tpair) + +type TVec = TArr (S Z) +type TMat = TArr (S (S Z)) + +neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R +neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $ + let layer :: (Lookup "wei" env ~ TMat R, Lookup "bias" env ~ TVec R, Lookup "x" env ~ TVec R) => NExpr env (TVec R) + layer = + -- prod = wei `matmul` x + let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :-> + #wei ! #idx * #x ! pair nil (snd_ #idx)) $ + -- relu (prod + bias) + build (SS SZ) (shape #prod) $ #idx :-> + let_ #out (#prod ! #idx + #bias ! #idx) $ + if_ (#out .<= const_ 0) (const_ 0) #out + + in let_ #x1 (let_ #wei (fst_ #layer1) $ let_ #bias (snd_ #layer1) $ let_ #x #input $ layer) $ + let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $ + let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ + #x3 ! nil diff --git a/src/Interpreter.hs b/src/Interpreter.hs index afc50f9..7ffb14c 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -1,8 +1,156 @@ -module Interpreter where +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE FlexibleContexts #-} +module Interpreter ( + interpret, + interpret', + Value, + NoAccum(..), + unAccum, +) where + +import Data.Int (Int64) +import Data.Proxy import AST -import Interpreter.Array +import Data +import Array import Interpreter.Accum +import Interpreter.Rep +import Control.Monad (foldM) + + +interpret :: NoAccum t => Ex '[] t -> Rep t +interpret e = runAcM (go e) + where + go :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t) + go e' | Refl <- noAccum (Proxy @s) (Proxy @t) = interpret' SNil e' + +newtype Value s t = Value (Rep' s t) + +interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t) +interpret' env = \case + EVar _ _ i -> case slistIdx env i of Value x -> return x + ELet _ a b -> do + x <- interpret' env a + interpret' (Value x `SCons` env) b + EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b + EFst _ e -> fst <$> interpret' env e + ESnd _ e -> snd <$> interpret' env e + ENil _ -> return () + EInl _ _ e -> Left <$> interpret' env e + EInr _ _ e -> Right <$> interpret' env e + ECase _ e a b -> interpret' env e >>= \case + Left x -> interpret' (Value x `SCons` env) a + Right y -> interpret' (Value y `SCons` env) b + EConstArr _ _ _ v -> return v + EBuild1 _ a b -> do + n <- fromIntegral @Int64 @Int <$> interpret' env a + arrayGenerateLinM (ShNil `ShCons` n) + (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b) + EBuild _ dim a b -> do + sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a + arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b) + EFold1Inner _ a b -> do + let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a + arr <- interpret' env b + let sh `ShCons` n = arrayShape arr + arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + ESum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e) + EReplicate1Inner _ a b -> do + n <- fromIntegral @Int64 @Int <$> interpret' env a + arr <- interpret' env b + let sh = arrayShape arr + arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx)) + EConst _ _ v -> return v + EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e + EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) + EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b) + EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e + EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e + EWith e1 e2 -> do + initval <- interpret' env e1 + withAccum (typeOf e1) initval $ \accum -> + interpret' (Value accum `SCons` env) e2 + EAccum i e1 e2 e3 -> do + idx <- interpret' env e1 + val <- interpret' env e2 + accum <- interpret' env e3 + accumAdd accum i idx val + EError _ s -> error $ "Interpreter: Program threw error: " ++ s + +interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t +interpretOp _ op arg = case op of + OAdd st -> numericIsNum st $ uncurry (+) arg + OMul st -> numericIsNum st $ uncurry (*) arg + ONeg st -> numericIsNum st $ negate arg + OLt st -> numericIsNum st $ uncurry (<) arg + OLe st -> numericIsNum st $ uncurry (<=) arg + OEq st -> numericIsNum st $ uncurry (==) arg + ONot -> not arg + OIf -> if arg then Left () else Right () + +numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r +numericIsNum STI32 = id +numericIsNum STI64 = id +numericIsNum STF32 = id +numericIsNum STF64 = id + +unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m)) + -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n +unTupRepIdx _ nil _ SZ _ = nil +unTupRepIdx p nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i + +tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int)) + -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx)) +tupRepIdx _ _ SZ _ = () +tupRepIdx p uncons (SS n) tup = + let (tup', i) = uncons tup + in (tupRepIdx p uncons n tup', fromIntegral @Int @Int64 i) + +ixUncons :: Index (S n) -> (Index n, Int) +ixUncons (IxCons idx i) = (idx, i) + +shUncons :: Shape (S n) -> (Shape n, Int) +shUncons (ShCons idx i) = (idx, i) +class NoAccum t where + noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t +instance NoAccum TNil where + noAccum _ _ = Refl +instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where + noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl +instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where + noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl +instance NoAccum t => NoAccum (TArr n t) where + noAccum p _ | Refl <- noAccum p (Proxy @t) = Refl +instance NoAccum (TScal t) where + noAccum _ _ = Refl +unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t)) +unAccum _ STNil = Just Dict +unAccum p (STPair t1 t2) + | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict + | otherwise = Nothing +unAccum p (STEither t1 t2) + | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict + | otherwise = Nothing +unAccum p (STArr _ t) + | Just Dict <- unAccum p t = Just Dict + | otherwise = Nothing +unAccum _ STScal{} = Just Dict +unAccum _ STAccum{} = Nothing +foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a +foldl1M _ [] = error "foldl1M: empty list" +foldl1M f (tophead : toptail) = foldM f tophead toptail diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs index d15ea10..b6a91df 100644 --- a/src/Interpreter/Accum.hs +++ b/src/Interpreter/Accum.hs @@ -15,7 +15,7 @@ module Interpreter.Accum ( AcM, runAcM, - Rep, + Rep', Accum, withAccum, accumAdd, @@ -25,6 +25,8 @@ module Interpreter.Accum ( import Control.Concurrent import Control.Monad (when, forM_) import Data.Bifunctor (second) +import Data.Proxy +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) import Foreign.Storable (sizeOf) import GHC.Exts import GHC.Float @@ -33,10 +35,9 @@ import GHC.IO (IO(..)) import GHC.Word import System.IO.Unsafe (unsafePerformIO) +import Array import AST import Data -import Interpreter.Array -import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) newtype AcM s a = AcM (IO a) @@ -45,34 +46,35 @@ newtype AcM s a = AcM (IO a) runAcM :: (forall s. AcM s a) -> a runAcM (AcM m) = unsafePerformIO m -type family Rep t where - Rep TNil = () - Rep (TPair a b) = (Rep a, Rep b) - Rep (TEither a b) = Either (Rep a) (Rep b) - Rep (TArr n t) = Array n (Rep t) - Rep (TScal sty) = ScalRep sty - -- Rep (TAccum t) = _ +-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. +type family Rep' s t where + Rep' s TNil = () + Rep' s (TPair a b) = (Rep' s a, Rep' s b) + Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) + Rep' s (TArr n t) = Array n (Rep' s t) + Rep' s (TScal sty) = ScalRep sty + Rep' s (TAccum t) = Accum s t -- | Floats and integers are accumulated; booleans are left as-is. data Accum s t = Accum (STy t) (ForeignPtr ()) -tSize :: STy t -> Rep t -> Int -tSize ty x = tSize' ty (Just x) +tSize :: Proxy s -> STy t -> Rep' s t -> Int +tSize p ty x = tSize' p ty (Just x) -- | Passing Nothing as the value means "this is (inside) an array element". -tSize' :: STy t -> Maybe (Rep t) -> Int -tSize' typ val = case typ of +tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int +tSize' p typ val = case typ of STNil -> 0 - STPair a b -> tSize' a (fst <$> val) + tSize' b (snd <$> val) + STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val) STEither a b -> case val of - Nothing -> 1 + max (tSize' a Nothing) (tSize' b Nothing) - Just (Left x) -> 1 + tSize a x -- '1 +' is for runtime sanity checking - Just (Right y) -> 1 + tSize b y -- idem + Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing) + Just (Left x) -> 1 + tSize' p a (Just x) -- '1 +' is for runtime sanity checking + Just (Right y) -> 1 + tSize' p b (Just y) -- idem STArr ndim t -> case val of Nothing -> error "Nested arrays not supported in this implementation" - Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' t Nothing + Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing STScal sty -> goScal sty STAccum{} -> error "Nested accumulators unsupported" where @@ -86,10 +88,10 @@ tSize' typ val = case typ of -- | This operation does not commute with 'accumAdd', so it must be used with -- care. Furthermore it must be used on exactly the same value as tSize was -- called on. Hence it lives in IO, not in AcM. -accumWrite :: forall s t. Accum s t -> Rep t -> IO () +accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> let - go :: Bool -> STy t' -> Rep t' -> Int -> IO Int + go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int go inarr ty val off = case ty of STNil -> return off STPair a b -> do @@ -110,7 +112,7 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> | inarr -> error "Nested arrays not supported in this implementation" | otherwise -> do off1 <- goShape (arrayShape val) off - let eltsize = tSize' t Nothing + let eltsize = tSize' (Proxy @s) t Nothing n = arraySize val traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val return (off1 + eltsize * n) @@ -136,10 +138,10 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> in () <$ go False topty top_value 0 -accumRead :: forall s t. Accum s t -> AcM s (Rep t) +accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> let - go :: Bool -> STy t' -> Int -> IO (Int, Rep t') + go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') go inarr ty off = case ty of STNil -> return (off, ()) STPair a b -> do @@ -154,13 +156,13 @@ accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> 1 -> fmap Right <$> go inarr b (off + 1) _ -> error "Invalid tag in accum memory" if inarr - then return (off + 1 + max (tSize' a Nothing) (tSize' b Nothing), val) + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) else return (off1, val) STArr ndim t | inarr -> error "Nested arrays not supported in this implementation" | otherwise -> do (off1, sh) <- readShape addr# ndim off - let eltsize = tSize' t Nothing + let eltsize = tSize' (Proxy @s) t Nothing n = shapeSize sh arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) return (off1 + eltsize * n, arr) @@ -205,10 +207,10 @@ invertShape | Refl <- lemPlusZero @n = flip go IShNil go ShNil ish = ish go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) -accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s () +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> let - go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> IO () + go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () go inarr ty SZ () val off = () <$ performAdd inarr ty val off go inarr ty (SS dep) idx val off = case (ty, idx, val) of (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off @@ -227,23 +229,23 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr (STAccum{}, _, _) -> error "Nested accumulators unsupported" goArr :: SNat i' -> InvShape n -> STy t' - -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> IO () + -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do - let i' = fromIntegral @(Rep TIx) @Int i + let i' = fromIntegral @(Rep' s TIx) @Int i when (i' < 0 || i' >= n) $ error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" goArr depm1 ish t1 idx val (off + i' * ishSize ish) - performAddArr :: Int -> STy t' -> Array n (Rep t') -> Int -> IO Int + performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int performAddArr arraySz eltty val off = do - let eltsize = tSize' eltty Nothing + let eltsize = tSize' (Proxy @s) eltty Nothing forM_ [0 .. arraySz - 1] $ \lini -> performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) return (off + arraySz * eltsize) - performAdd :: Bool -> STy t' -> Rep t' -> Int -> IO Int + performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int performAdd inarr ty val off = case ty of STNil -> return off STPair t1 t2 -> do @@ -257,7 +259,7 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) _ -> error "accumAdd: Tag mismatch for Either" if inarr - then return (off + 1 + max (tSize' t1 Nothing) (tSize' t2 Nothing)) + then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) else return off1 STArr n ty' | inarr -> error "Nested array" @@ -300,18 +302,18 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr in () <$ go False topty top_depth top_index top_value 0 -withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b) +withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) withAccum ty start fun = do -- The initial write must happen before any of the adds or reads, so it makes -- sense to put it in IO together with the allocation, instead of in AcM. - accum <- AcM $ do buffer <- mallocBytes (tSize ty start) + accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) ptr <- newForeignPtr finalizerFree buffer let accum = Accum ty ptr accumWrite accum start return accum b <- fun accum out <- accumRead accum - return (out, b) + return (b, out) inParallel :: [AcM s t] -> AcM s [t] inParallel actions = AcM $ do diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs new file mode 100644 index 0000000..1ded773 --- /dev/null +++ b/src/Interpreter/Rep.hs @@ -0,0 +1,17 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +module Interpreter.Rep where + +import GHC.TypeError + +import Array +import AST + + +type family Rep t where + Rep TNil = () + Rep (TPair a b) = (Rep a, Rep b) + Rep (TEither a b) = Either (Rep a) (Rep b) + Rep (TArr n t) = Array n (Rep t) + Rep (TScal sty) = ScalRep sty + Rep (TAccum t) = TypeError (Text "Accumulator in Rep") diff --git a/src/Language.hs b/src/Language.hs index 58a7070..cdc6d6b 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -8,6 +8,7 @@ module Language ( Lookup, ) where +import Array import AST import Data import Language.AST @@ -49,18 +50,30 @@ inr = NEInr case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 +constArr_ :: (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) +constArr_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConstArr knownNat ty x + build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) build1 a (v :-> b) = NEBuild1 a v b build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) build n a (v :-> b) = NEBuild n a v b -fold1 :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) -fold1 (v1 :-> v2 :-> e1) e2 = NEFold1 v1 v2 e1 e2 +fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i (v1 :-> v2 :-> e1) e2 = NEFold1Inner v1 v2 e1 e2 + +sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +sum1i e = NESum1Inner e unit :: NExpr env t -> NExpr env (TArr Z t) unit = NEUnit +replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t)) +replicate1i n a = NEReplicate1Inner n a + const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) const_ x = let ty = knownScalTy @@ -72,9 +85,11 @@ idx0 = NEIdx0 (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) (.!) = NEIdx1 +infixl 9 .! -(!) :: SNat n -> NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t -(!) = NEIdx +(!) :: KnownNat n => NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t +(!) = NEIdx knownNat +infixl 9 ! shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) shape = NEShape @@ -85,20 +100,25 @@ oper = NEOp error_ :: KnownTy t => String -> NExpr env t error_ s = NEError knownTy s -(.==) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) a .== b = oper (OEq knownScalTy) (pair a b) +infix 4 .== -(.<) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) a .< b = oper (OLt knownScalTy) (pair a b) +infix 4 .< -(.>) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) (.>) = flip (.<) +infix 4 .> -(.<=) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) a .<= b = oper (OLe knownScalTy) (pair a b) +infix 4 .<= -(.>=) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) (.>=) = flip (.<=) +infix 4 .>= not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) not_ = oper ONot diff --git a/src/Language/AST.hs b/src/Language/AST.hs index 511723a..af5a5a2 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -19,6 +19,7 @@ import Data.Type.Equality import GHC.OverloadedLabels import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(Text)) +import Array import AST import Data @@ -39,10 +40,13 @@ data NExpr env t where NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c -- array operations + NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) NEBuild1 :: NExpr env TIx -> Var name TIx -> NExpr ('(name, TIx) : env) t -> NExpr env (TArr (S Z) t) NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) - NEFold1 :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEUnit :: NExpr env t -> NExpr env (TArr Z t) + NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) -- expression operations NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) @@ -64,7 +68,7 @@ type family Lookup name env where data Var name t = Var (SSymbol name) (STy t) deriving (Show) -instance (t ~ TScal st, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where +instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where a + b = NEOp (OAdd knownScalTy) (NEPair a b) a * b = NEOp (OMul knownScalTy) (NEPair a b) negate e = NEOp (ONeg knownScalTy) e @@ -116,10 +120,13 @@ fromNamedExpr val = \case NEInr t e -> EInr ext t (go e) NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b) + NEConstArr n t x -> EConstArr ext n t x NEBuild1 a n b -> EBuild1 ext (go a) (lambda val n b) NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEFold1 n1 n2 a b -> EFold1 ext (lambda2 val n1 n2 a) (go b) + NEFold1Inner n1 n2 a b -> EFold1Inner ext (lambda2 val n1 n2 a) (go b) + NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) + NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) NEConst t x -> EConst ext t x NEIdx0 e -> EIdx0 ext (go e) diff --git a/src/Simplify.hs b/src/Simplify.hs index c44d965..1640729 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -77,11 +77,13 @@ simplify' = \case EInl _ t e -> EInl ext t (simplify' e) EInr _ t e -> EInr ext t (simplify' e) ECase _ e a b -> ECase ext (simplify' e) (simplify' a) (simplify' b) + EConstArr _ n t v -> EConstArr ext n t v EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b) EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b) - EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b) + EFold1Inner _ a b -> EFold1Inner ext (simplify' a) (simplify' b) + ESum1Inner _ e -> ESum1Inner ext (simplify' e) EUnit _ e -> EUnit ext (simplify' e) - -- EReplicate _ e -> EReplicate ext (simplify' e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (simplify' a) (simplify' b) EConst _ t v -> EConst ext t v EIdx0 _ e -> EIdx0 ext (simplify' e) EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b) @@ -112,11 +114,13 @@ hasAdds = \case EInl _ _ e -> hasAdds e EInr _ _ e -> hasAdds e ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b + EConstArr _ _ _ _ -> False EBuild1 _ a b -> hasAdds a || hasAdds b EBuild _ _ a b -> hasAdds a || hasAdds b - EFold1 _ a b -> hasAdds a || hasAdds b + EFold1Inner _ a b -> hasAdds a || hasAdds b + ESum1Inner _ e -> hasAdds e EUnit _ e -> hasAdds e - -- EReplicate _ e -> hasAdds e + EReplicate1Inner _ a b -> hasAdds a || hasAdds b EConst _ _ _ -> False EIdx0 _ e -> hasAdds e EIdx1 _ a b -> hasAdds a || hasAdds b |