summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal3
-rw-r--r--src/AST.hs49
-rw-r--r--src/AST/Count.hs6
-rw-r--r--src/AST/Pretty.hs29
-rw-r--r--src/Array.hs (renamed from src/Interpreter/Array.hs)26
-rw-r--r--src/CHAD.hs11
-rw-r--r--src/Example.hs20
-rw-r--r--src/Interpreter.hs152
-rw-r--r--src/Interpreter/Accum.hs76
-rw-r--r--src/Interpreter/Rep.hs17
-rw-r--r--src/Language.hs38
-rw-r--r--src/Language/AST.hs13
-rw-r--r--src/Simplify.hs12
13 files changed, 362 insertions, 90 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index ef1fd66..6314acd 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -10,6 +10,7 @@ build-type: Simple
library
exposed-modules:
+ Array
AST
AST.Count
AST.Env
@@ -22,7 +23,7 @@ library
Example
Interpreter
Interpreter.Accum
- Interpreter.Array
+ Interpreter.Rep
Language
Language.AST
Lemmas
diff --git a/src/AST.hs b/src/AST.hs
index 785e34a..2132bc6 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -20,6 +20,7 @@ import Data.Kind (Type)
import Data.Int
import Data.Type.Equality
+import Array
import AST.Env
import AST.Weaken
import Data
@@ -91,6 +92,13 @@ type family ScalRep t where
ScalRep TF64 = Double
ScalRep TBool = Bool
+type family ScalIsNumeric t where
+ ScalIsNumeric TI32 = True
+ ScalIsNumeric TI64 = True
+ ScalIsNumeric TF32 = True
+ ScalIsNumeric TF64 = True
+ ScalIsNumeric TBool = False
+
-- | This index is flipped around from the usual direction: the smallest index
-- is at the _heart_ of the nesting, not at the outside. The outermost layer
-- indexes into the _outer_ dimension of the type @t@. This makes indices into
@@ -128,11 +136,13 @@ data Expr x env t where
ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
-- array operations
+ EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
- EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
+ EFold1Inner :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
+ ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
- -- EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused
+ EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
@@ -160,13 +170,16 @@ type family Tup env where
Tup '[] = TNil
Tup (t : ts) = TPair (Tup ts) t
+mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b))
+ -> SList f list -> f (Tup list)
+mkTup nil _ SNil = nil
+mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e
+
tTup :: SList STy env -> STy (Tup env)
-tTup SNil = STNil
-tTup (SCons t ts) = STPair (tTup ts) t
+tTup = mkTup STNil STPair
eTup :: SList (Ex env) list -> Ex env (Tup list)
-eTup SNil = ENil ext
-eTup (e `SCons` es) = EPair ext (eTup es) e
+eTup = mkTup (ENil ext) (EPair ext)
type family InvTup core env where
InvTup core '[] = core
@@ -174,12 +187,12 @@ type family InvTup core env where
type SOp :: Ty -> Ty -> Type
data SOp a t where
- OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
- OMul :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
- ONeg :: SScalTy a -> SOp (TScal a) (TScal a)
- OLt :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
- OLe :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
- OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
+ ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
+ OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
+ OEq :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
ONot :: SOp (TScal TBool) (TScal TBool)
OIf :: SOp (TScal TBool) (TEither TNil TNil)
deriving instance Show (SOp a t)
@@ -208,11 +221,13 @@ typeOf = \case
EInr _ t1 e -> STEither t1 (typeOf e)
ECase _ _ a _ -> typeOf a
+ EConstArr _ n t _ -> STArr n (STScal t)
EBuild1 _ _ e -> STArr (SS SZ) (typeOf e)
EBuild _ n _ e -> STArr n (typeOf e)
- EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ EFold1Inner _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
+ ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EUnit _ e -> STArr SZ (typeOf e)
- -- EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t
+ EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
@@ -273,11 +288,13 @@ subst' f w = \case
EInl x t e -> EInl x t (subst' f w e)
EInr x t e -> EInr x t (subst' f w e)
ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
+ EConstArr x n t a -> EConstArr x n t a
EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
- EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
+ EFold1Inner x a b -> EFold1Inner x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
+ ESum1Inner x e -> ESum1Inner x (subst' f w e)
EUnit x e -> EUnit x (subst' f w e)
- -- EReplicate x e -> EReplicate x (subst' f w e)
+ EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 39d26c2..6a00e83 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -109,11 +109,13 @@ occCountGeneral onehot unpush alter many = go WId
EInl _ _ e -> re e
EInr _ _ e -> re e
ECase _ e a b -> re e <> (re1 a `alter` re1 b)
+ EConstArr{} -> mempty
EBuild1 _ a b -> re a <> many (re1 b)
EBuild _ _ a b -> re a <> many (re1 b)
- EFold1 _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b
+ EFold1Inner _ a b -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b
+ ESum1Inner _ e -> re e
EUnit _ e -> re e
- -- EReplicate _ e -> re e
+ EReplicate1Inner _ a b -> re a <> re b
EConst{} -> mempty
EIdx0 _ e -> re e
EIdx1 _ a b -> re a <> re b
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index bf0d350..2ce883b 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -42,11 +42,6 @@ genNameIfUsedIn' prefix ty idx ex
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn = genNameIfUsedIn' "x"
-valprj :: SList f env -> Idx env t -> f t
-valprj (x `SCons` _) IZ = x
-valprj (_ `SCons` env) (IS i) = valprj env i
-valprj SNil i = case i of {}
-
ppExpr :: SList STy env -> Expr x env t -> String
ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
where
@@ -59,7 +54,7 @@ ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS
ppExpr' d val = \case
- EVar _ _ i -> return $ showString $ getConst $ valprj val i
+ EVar _ _ i -> return $ showString $ getConst $ slistIdx val i
e@ELet{} -> ppExprLet d val e
@@ -97,6 +92,9 @@ ppExpr' d val = \case
showString "case " . e' . showString (" of { Inl " ++ name1 ++ " -> ") . a'
. showString (" ; Inr " ++ name2 ++ " -> ") . b' . showString " }"
+ EConstArr _ _ ty v
+ | Dict <- scalRepIsShow ty -> return $ showsPrec d v
+
EBuild1 _ a b -> do
a' <- ppExpr' 11 val a
name <- genNameIfUsedIn (STScal STI64) IZ b
@@ -111,25 +109,30 @@ ppExpr' d val = \case
return $ showParen (d > 10) $
showString "build " . a' . showString (" (\\" ++ name ++ " -> ") . e' . showString ")"
- EFold1 _ a b -> do
+ EFold1Inner _ a b -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
name2 <- genNameIfUsedIn (typeOf a) IZ a
a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
b' <- ppExpr' 11 val b
return $ showParen (d > 10) $
- showString ("fold1 (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a'
+ showString ("fold1i (\\" ++ name1 ++ " " ++ name2 ++ " -> ") . a'
. showString ") " . b'
+ ESum1Inner _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "sum1i " . e'
+
EUnit _ e -> do
e' <- ppExpr' 11 val e
return $ showParen (d > 10) $ showString "unit " . e'
- -- EReplicate _ e -> do
- -- e' <- ppExpr' 11 val e
- -- return $ showParen (d > 10) $ showString "replicate " . e'
+ EReplicate1Inner _ a b -> do
+ a' <- ppExpr' 11 val a
+ b' <- ppExpr' 11 val b
+ return $ showParen (d > 10) $ showString "replicate1i " . a' . showString " " . b'
- EConst _ ty v -> return $ showString $ case ty of
- STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v
+ EConst _ ty v
+ | Dict <- scalRepIsShow ty -> return $ showsPrec d v
EIdx0 _ e -> do
e' <- ppExpr' 11 val e
diff --git a/src/Interpreter/Array.hs b/src/Array.hs
index 54e0791..9a770c4 100644
--- a/src/Interpreter/Array.hs
+++ b/src/Array.hs
@@ -1,8 +1,9 @@
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
-module Interpreter.Array where
+module Array where
import Control.Monad.Trans.State.Strict
import Data.Foldable (traverse_)
@@ -15,18 +16,32 @@ import Data
data Shape n where
ShNil :: Shape Z
ShCons :: Shape n -> Int -> Shape (S n)
+deriving instance Show (Shape n)
data Index n where
IxNil :: Index Z
IxCons :: Index n -> Int -> Index (S n)
+deriving instance Show (Index n)
shapeSize :: Shape n -> Int
shapeSize ShNil = 0
shapeSize (ShCons sh n) = shapeSize sh * n
+fromLinearIndex :: Shape n -> Int -> Index n
+fromLinearIndex ShNil 0 = IxNil
+fromLinearIndex ShNil _ = error "Index out of range"
+fromLinearIndex (sh `ShCons` n) i =
+ let (q, r) = i `quotRem` n
+ in fromLinearIndex sh q `IxCons` r
+
+toLinearIndex :: Shape n -> Index n -> Int
+toLinearIndex ShNil IxNil = 0
+toLinearIndex (sh `ShCons` n) (idx `IxCons` i) = toLinearIndex sh idx * n + i
+
-- | TODO: this Vector is a boxed vector, which is horrendously inefficient.
data Array (n :: Nat) t = Array (Shape n) (Vector t)
+ deriving (Show)
arrayShape :: Array n t -> Shape n
arrayShape (Array sh _) = sh
@@ -34,9 +49,18 @@ arrayShape (Array sh _) = sh
arraySize :: Array n t -> Int
arraySize (Array sh _) = shapeSize sh
+arrayIndex :: Array n t -> Index n -> t
+arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx)
+
arrayIndexLinear :: Array n t -> Int -> t
arrayIndexLinear (Array _ v) i = v V.! i
+arrayIndex1 :: Array (S n) t -> Int -> Array n t
+arrayIndex1 (Array (sh `ShCons` _) v) i = let sz = shapeSize sh in Array sh (V.slice (sz * i) sz v)
+
+arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t)
+arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh)
+
arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t)
arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 692bb96..943f0a2 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -968,6 +968,12 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
+ EConstArr _ n t val ->
+ Ret BTop
+ (EConstArr ext n t val)
+ (subenvNone (select SMerge des))
+ (ENil ext)
+
EBuild1 _ ne (orige :: Ex _ eltty)
| Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne -- allowed to ignore ne2 here because ne has a discrete result
, let eltty = typeOf orige ->
@@ -1075,9 +1081,10 @@ drev des = \case
weakenExpr (WCopy (WSink .> WSink)) e2)
-- These should be the next to be implemented, I think
- EFold1{} -> err_unsupported "EFold1"
+ ESum1Inner{} -> err_unsupported "ESum"
+ EReplicate1Inner{} -> err_unsupported "EReplicate"
EShape{} -> err_unsupported "EShape"
- -- EReplicate{} -> err_unsupported "EReplicate"
+ EFold1Inner{} -> err_unsupported "EFold1Inner"
EIdx{} -> err_unsupported "EIdx"
EBuild{} -> err_unsupported "EBuild"
diff --git a/src/Example.hs b/src/Example.hs
index 4130f47..d1d04e3 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -140,3 +140,23 @@ ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $
in let_ #parstup #pars123 $
let_ #inp #input $
layer (STPair (STPair (STPair STNil tpair) tpair) tpair)
+
+type TVec = TArr (S Z)
+type TMat = TArr (S (S Z))
+
+neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R
+neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $
+ let layer :: (Lookup "wei" env ~ TMat R, Lookup "bias" env ~ TVec R, Lookup "x" env ~ TVec R) => NExpr env (TVec R)
+ layer =
+ -- prod = wei `matmul` x
+ let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :->
+ #wei ! #idx * #x ! pair nil (snd_ #idx)) $
+ -- relu (prod + bias)
+ build (SS SZ) (shape #prod) $ #idx :->
+ let_ #out (#prod ! #idx + #bias ! #idx) $
+ if_ (#out .<= const_ 0) (const_ 0) #out
+
+ in let_ #x1 (let_ #wei (fst_ #layer1) $ let_ #bias (snd_ #layer1) $ let_ #x #input $ layer) $
+ let_ #x2 (let_ #wei (fst_ #layer2) $ let_ #bias (snd_ #layer2) $ let_ #x #x1 $ layer) $
+ let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $
+ #x3 ! nil
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index afc50f9..7ffb14c 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -1,8 +1,156 @@
-module Interpreter where
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE FlexibleContexts #-}
+module Interpreter (
+ interpret,
+ interpret',
+ Value,
+ NoAccum(..),
+ unAccum,
+) where
+
+import Data.Int (Int64)
+import Data.Proxy
import AST
-import Interpreter.Array
+import Data
+import Array
import Interpreter.Accum
+import Interpreter.Rep
+import Control.Monad (foldM)
+
+
+interpret :: NoAccum t => Ex '[] t -> Rep t
+interpret e = runAcM (go e)
+ where
+ go :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t)
+ go e' | Refl <- noAccum (Proxy @s) (Proxy @t) = interpret' SNil e'
+
+newtype Value s t = Value (Rep' s t)
+
+interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t)
+interpret' env = \case
+ EVar _ _ i -> case slistIdx env i of Value x -> return x
+ ELet _ a b -> do
+ x <- interpret' env a
+ interpret' (Value x `SCons` env) b
+ EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
+ EFst _ e -> fst <$> interpret' env e
+ ESnd _ e -> snd <$> interpret' env e
+ ENil _ -> return ()
+ EInl _ _ e -> Left <$> interpret' env e
+ EInr _ _ e -> Right <$> interpret' env e
+ ECase _ e a b -> interpret' env e >>= \case
+ Left x -> interpret' (Value x `SCons` env) a
+ Right y -> interpret' (Value y `SCons` env) b
+ EConstArr _ _ _ v -> return v
+ EBuild1 _ a b -> do
+ n <- fromIntegral @Int64 @Int <$> interpret' env a
+ arrayGenerateLinM (ShNil `ShCons` n)
+ (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
+ EBuild _ dim a b -> do
+ sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a
+ arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b)
+ EFold1Inner _ a b -> do
+ let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
+ arr <- interpret' env b
+ let sh `ShCons` n = arrayShape arr
+ arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ ESum1Inner _ e -> do
+ arr <- interpret' env e
+ let STArr _ (STScal t) = typeOf e
+ sh `ShCons` n = arrayShape arr
+ numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
+ EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
+ EReplicate1Inner _ a b -> do
+ n <- fromIntegral @Int64 @Int <$> interpret' env a
+ arr <- interpret' env b
+ let sh = arrayShape arr
+ arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx))
+ EConst _ _ v -> return v
+ EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
+ EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
+ EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b)
+ EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e
+ EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e
+ EWith e1 e2 -> do
+ initval <- interpret' env e1
+ withAccum (typeOf e1) initval $ \accum ->
+ interpret' (Value accum `SCons` env) e2
+ EAccum i e1 e2 e3 -> do
+ idx <- interpret' env e1
+ val <- interpret' env e2
+ accum <- interpret' env e3
+ accumAdd accum i idx val
+ EError _ s -> error $ "Interpreter: Program threw error: " ++ s
+
+interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t
+interpretOp _ op arg = case op of
+ OAdd st -> numericIsNum st $ uncurry (+) arg
+ OMul st -> numericIsNum st $ uncurry (*) arg
+ ONeg st -> numericIsNum st $ negate arg
+ OLt st -> numericIsNum st $ uncurry (<) arg
+ OLe st -> numericIsNum st $ uncurry (<=) arg
+ OEq st -> numericIsNum st $ uncurry (==) arg
+ ONot -> not arg
+ OIf -> if arg then Left () else Right ()
+
+numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
+numericIsNum STI32 = id
+numericIsNum STI64 = id
+numericIsNum STF32 = id
+numericIsNum STF64 = id
+
+unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m))
+ -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n
+unTupRepIdx _ nil _ SZ _ = nil
+unTupRepIdx p nil cons (SS n) (idx, i) = unTupRepIdx p nil cons n idx `cons` fromIntegral @Int64 @Int i
+
+tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int))
+ -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx))
+tupRepIdx _ _ SZ _ = ()
+tupRepIdx p uncons (SS n) tup =
+ let (tup', i) = uncons tup
+ in (tupRepIdx p uncons n tup', fromIntegral @Int @Int64 i)
+
+ixUncons :: Index (S n) -> (Index n, Int)
+ixUncons (IxCons idx i) = (idx, i)
+
+shUncons :: Shape (S n) -> (Shape n, Int)
+shUncons (ShCons idx i) = (idx, i)
+class NoAccum t where
+ noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t
+instance NoAccum TNil where
+ noAccum _ _ = Refl
+instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where
+ noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl
+instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where
+ noAccum p _ | Refl <- noAccum p (Proxy @a), Refl <- noAccum p (Proxy @b) = Refl
+instance NoAccum t => NoAccum (TArr n t) where
+ noAccum p _ | Refl <- noAccum p (Proxy @t) = Refl
+instance NoAccum (TScal t) where
+ noAccum _ _ = Refl
+unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t))
+unAccum _ STNil = Just Dict
+unAccum p (STPair t1 t2)
+ | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict
+ | otherwise = Nothing
+unAccum p (STEither t1 t2)
+ | Just Dict <- unAccum p t1, Just Dict <- unAccum p t2 = Just Dict
+ | otherwise = Nothing
+unAccum p (STArr _ t)
+ | Just Dict <- unAccum p t = Just Dict
+ | otherwise = Nothing
+unAccum _ STScal{} = Just Dict
+unAccum _ STAccum{} = Nothing
+foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
+foldl1M _ [] = error "foldl1M: empty list"
+foldl1M f (tophead : toptail) = foldM f tophead toptail
diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs
index d15ea10..b6a91df 100644
--- a/src/Interpreter/Accum.hs
+++ b/src/Interpreter/Accum.hs
@@ -15,7 +15,7 @@
module Interpreter.Accum (
AcM,
runAcM,
- Rep,
+ Rep',
Accum,
withAccum,
accumAdd,
@@ -25,6 +25,8 @@ module Interpreter.Accum (
import Control.Concurrent
import Control.Monad (when, forM_)
import Data.Bifunctor (second)
+import Data.Proxy
+import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
import Foreign.Storable (sizeOf)
import GHC.Exts
import GHC.Float
@@ -33,10 +35,9 @@ import GHC.IO (IO(..))
import GHC.Word
import System.IO.Unsafe (unsafePerformIO)
+import Array
import AST
import Data
-import Interpreter.Array
-import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr)
newtype AcM s a = AcM (IO a)
@@ -45,34 +46,35 @@ newtype AcM s a = AcM (IO a)
runAcM :: (forall s. AcM s a) -> a
runAcM (AcM m) = unsafePerformIO m
-type family Rep t where
- Rep TNil = ()
- Rep (TPair a b) = (Rep a, Rep b)
- Rep (TEither a b) = Either (Rep a) (Rep b)
- Rep (TArr n t) = Array n (Rep t)
- Rep (TScal sty) = ScalRep sty
- -- Rep (TAccum t) = _
+-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined.
+type family Rep' s t where
+ Rep' s TNil = ()
+ Rep' s (TPair a b) = (Rep' s a, Rep' s b)
+ Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b)
+ Rep' s (TArr n t) = Array n (Rep' s t)
+ Rep' s (TScal sty) = ScalRep sty
+ Rep' s (TAccum t) = Accum s t
-- | Floats and integers are accumulated; booleans are left as-is.
data Accum s t = Accum (STy t) (ForeignPtr ())
-tSize :: STy t -> Rep t -> Int
-tSize ty x = tSize' ty (Just x)
+tSize :: Proxy s -> STy t -> Rep' s t -> Int
+tSize p ty x = tSize' p ty (Just x)
-- | Passing Nothing as the value means "this is (inside) an array element".
-tSize' :: STy t -> Maybe (Rep t) -> Int
-tSize' typ val = case typ of
+tSize' :: Proxy s -> STy t -> Maybe (Rep' s t) -> Int
+tSize' p typ val = case typ of
STNil -> 0
- STPair a b -> tSize' a (fst <$> val) + tSize' b (snd <$> val)
+ STPair a b -> tSize' p a (fst <$> val) + tSize' p b (snd <$> val)
STEither a b ->
case val of
- Nothing -> 1 + max (tSize' a Nothing) (tSize' b Nothing)
- Just (Left x) -> 1 + tSize a x -- '1 +' is for runtime sanity checking
- Just (Right y) -> 1 + tSize b y -- idem
+ Nothing -> 1 + max (tSize' p a Nothing) (tSize' p b Nothing)
+ Just (Left x) -> 1 + tSize' p a (Just x) -- '1 +' is for runtime sanity checking
+ Just (Right y) -> 1 + tSize' p b (Just y) -- idem
STArr ndim t ->
case val of
Nothing -> error "Nested arrays not supported in this implementation"
- Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' t Nothing
+ Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing
STScal sty -> goScal sty
STAccum{} -> error "Nested accumulators unsupported"
where
@@ -86,10 +88,10 @@ tSize' typ val = case typ of
-- | This operation does not commute with 'accumAdd', so it must be used with
-- care. Furthermore it must be used on exactly the same value as tSize was
-- called on. Hence it lives in IO, not in AcM.
-accumWrite :: forall s t. Accum s t -> Rep t -> IO ()
+accumWrite :: forall s t. Accum s t -> Rep' s t -> IO ()
accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
let
- go :: Bool -> STy t' -> Rep t' -> Int -> IO Int
+ go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
go inarr ty val off = case ty of
STNil -> return off
STPair a b -> do
@@ -110,7 +112,7 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
off1 <- goShape (arrayShape val) off
- let eltsize = tSize' t Nothing
+ let eltsize = tSize' (Proxy @s) t Nothing
n = arraySize val
traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val
return (off1 + eltsize * n)
@@ -136,10 +138,10 @@ accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) ->
in () <$ go False topty top_value 0
-accumRead :: forall s t. Accum s t -> AcM s (Rep t)
+accumRead :: forall s t. Accum s t -> AcM s (Rep' s t)
accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
let
- go :: Bool -> STy t' -> Int -> IO (Int, Rep t')
+ go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t')
go inarr ty off = case ty of
STNil -> return (off, ())
STPair a b -> do
@@ -154,13 +156,13 @@ accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
1 -> fmap Right <$> go inarr b (off + 1)
_ -> error "Invalid tag in accum memory"
if inarr
- then return (off + 1 + max (tSize' a Nothing) (tSize' b Nothing), val)
+ then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val)
else return (off1, val)
STArr ndim t
| inarr -> error "Nested arrays not supported in this implementation"
| otherwise -> do
(off1, sh) <- readShape addr# ndim off
- let eltsize = tSize' t Nothing
+ let eltsize = tSize' (Proxy @s) t Nothing
n = shapeSize sh
arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini))
return (off1 + eltsize * n, arr)
@@ -205,10 +207,10 @@ invertShape | Refl <- lemPlusZero @n = flip go IShNil
go ShNil ish = ish
go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish)
-accumAdd :: forall s t i. Accum s t -> SNat i -> Rep (AcIdx t i) -> Rep (AcVal t i) -> AcM s ()
+accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s ()
accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) ->
let
- go :: Bool -> STy t' -> SNat i' -> Rep (AcIdx t' i') -> Rep (AcVal t' i') -> Int -> IO ()
+ go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO ()
go inarr ty SZ () val off = () <$ performAdd inarr ty val off
go inarr ty (SS dep) idx val off = case (ty, idx, val) of
(STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off
@@ -227,23 +229,23 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr
(STAccum{}, _, _) -> error "Nested accumulators unsupported"
goArr :: SNat i' -> InvShape n -> STy t'
- -> Rep (AcIdx (TArr n t') i') -> Rep (AcVal (TArr n t') i') -> Int -> IO ()
+ -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO ()
goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off
goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off
goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do
- let i' = fromIntegral @(Rep TIx) @Int i
+ let i' = fromIntegral @(Rep' s TIx) @Int i
when (i' < 0 || i' >= n) $
error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")"
goArr depm1 ish t1 idx val (off + i' * ishSize ish)
- performAddArr :: Int -> STy t' -> Array n (Rep t') -> Int -> IO Int
+ performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int
performAddArr arraySz eltty val off = do
- let eltsize = tSize' eltty Nothing
+ let eltsize = tSize' (Proxy @s) eltty Nothing
forM_ [0 .. arraySz - 1] $ \lini ->
performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize)
return (off + arraySz * eltsize)
- performAdd :: Bool -> STy t' -> Rep t' -> Int -> IO Int
+ performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int
performAdd inarr ty val off = case ty of
STNil -> return off
STPair t1 t2 -> do
@@ -257,7 +259,7 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr
(Right val2, 1) -> performAdd inarr t2 val2 (off + 1)
_ -> error "accumAdd: Tag mismatch for Either"
if inarr
- then return (off + 1 + max (tSize' t1 Nothing) (tSize' t2 Nothing))
+ then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing))
else return off1
STArr n ty'
| inarr -> error "Nested array"
@@ -300,18 +302,18 @@ accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr
in () <$ go False topty top_depth top_index top_value 0
-withAccum :: STy t -> Rep t -> (Accum s t -> AcM s b) -> AcM s (Rep t, b)
+withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t)
withAccum ty start fun = do
-- The initial write must happen before any of the adds or reads, so it makes
-- sense to put it in IO together with the allocation, instead of in AcM.
- accum <- AcM $ do buffer <- mallocBytes (tSize ty start)
+ accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start)
ptr <- newForeignPtr finalizerFree buffer
let accum = Accum ty ptr
accumWrite accum start
return accum
b <- fun accum
out <- accumRead accum
- return (out, b)
+ return (b, out)
inParallel :: [AcM s t] -> AcM s [t]
inParallel actions = AcM $ do
diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs
new file mode 100644
index 0000000..1ded773
--- /dev/null
+++ b/src/Interpreter/Rep.hs
@@ -0,0 +1,17 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+module Interpreter.Rep where
+
+import GHC.TypeError
+
+import Array
+import AST
+
+
+type family Rep t where
+ Rep TNil = ()
+ Rep (TPair a b) = (Rep a, Rep b)
+ Rep (TEither a b) = Either (Rep a) (Rep b)
+ Rep (TArr n t) = Array n (Rep t)
+ Rep (TScal sty) = ScalRep sty
+ Rep (TAccum t) = TypeError (Text "Accumulator in Rep")
diff --git a/src/Language.hs b/src/Language.hs
index 58a7070..cdc6d6b 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -8,6 +8,7 @@ module Language (
Lookup,
) where
+import Array
import AST
import Data
import Language.AST
@@ -49,18 +50,30 @@ inr = NEInr
case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c
case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2
+constArr_ :: (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
+constArr_ x =
+ let ty = knownScalTy
+ in case scalRepIsShow ty of
+ Dict -> NEConstArr knownNat ty x
+
build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t)
build1 a (v :-> b) = NEBuild1 a v b
build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t)
build n a (v :-> b) = NEBuild n a v b
-fold1 :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
-fold1 (v1 :-> v2 :-> e1) e2 = NEFold1 v1 v2 e1 e2
+fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+fold1i (v1 :-> v2 :-> e1) e2 = NEFold1Inner v1 v2 e1 e2
+
+sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
+sum1i e = NESum1Inner e
unit :: NExpr env t -> NExpr env (TArr Z t)
unit = NEUnit
+replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t))
+replicate1i n a = NEReplicate1Inner n a
+
const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t)
const_ x =
let ty = knownScalTy
@@ -72,9 +85,11 @@ idx0 = NEIdx0
(.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t)
(.!) = NEIdx1
+infixl 9 .!
-(!) :: SNat n -> NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
-(!) = NEIdx
+(!) :: KnownNat n => NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t
+(!) = NEIdx knownNat
+infixl 9 !
shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx))
shape = NEShape
@@ -85,20 +100,25 @@ oper = NEOp
error_ :: KnownTy t => String -> NExpr env t
error_ s = NEError knownTy s
-(.==) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
a .== b = oper (OEq knownScalTy) (pair a b)
+infix 4 .==
-(.<) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
a .< b = oper (OLt knownScalTy) (pair a b)
+infix 4 .<
-(.>) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
(.>) = flip (.<)
+infix 4 .>
-(.<=) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
a .<= b = oper (OLe knownScalTy) (pair a b)
+infix 4 .<=
-(.>=) :: KnownScalTy st => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
+(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
(.>=) = flip (.<=)
+infix 4 .>=
not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool)
not_ = oper ONot
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 511723a..af5a5a2 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -19,6 +19,7 @@ import Data.Type.Equality
import GHC.OverloadedLabels
import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(Text))
+import Array
import AST
import Data
@@ -39,10 +40,13 @@ data NExpr env t where
NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c
-- array operations
+ NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t))
NEBuild1 :: NExpr env TIx -> Var name TIx -> NExpr ('(name, TIx) : env) t -> NExpr env (TArr (S Z) t)
NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t)
- NEFold1 :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+ NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t)
+ NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t))
NEUnit :: NExpr env t -> NExpr env (TArr Z t)
+ NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t)
-- expression operations
NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t)
@@ -64,7 +68,7 @@ type family Lookup name env where
data Var name t = Var (SSymbol name) (STy t)
deriving (Show)
-instance (t ~ TScal st, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where
+instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where
a + b = NEOp (OAdd knownScalTy) (NEPair a b)
a * b = NEOp (OMul knownScalTy) (NEPair a b)
negate e = NEOp (ONeg knownScalTy) e
@@ -116,10 +120,13 @@ fromNamedExpr val = \case
NEInr t e -> EInr ext t (go e)
NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b)
+ NEConstArr n t x -> EConstArr ext n t x
NEBuild1 a n b -> EBuild1 ext (go a) (lambda val n b)
NEBuild k a n b -> EBuild ext k (go a) (lambda val n b)
- NEFold1 n1 n2 a b -> EFold1 ext (lambda2 val n1 n2 a) (go b)
+ NEFold1Inner n1 n2 a b -> EFold1Inner ext (lambda2 val n1 n2 a) (go b)
+ NESum1Inner e -> ESum1Inner ext (go e)
NEUnit e -> EUnit ext (go e)
+ NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b)
NEConst t x -> EConst ext t x
NEIdx0 e -> EIdx0 ext (go e)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index c44d965..1640729 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -77,11 +77,13 @@ simplify' = \case
EInl _ t e -> EInl ext t (simplify' e)
EInr _ t e -> EInr ext t (simplify' e)
ECase _ e a b -> ECase ext (simplify' e) (simplify' a) (simplify' b)
+ EConstArr _ n t v -> EConstArr ext n t v
EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b)
EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b)
- EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b)
+ EFold1Inner _ a b -> EFold1Inner ext (simplify' a) (simplify' b)
+ ESum1Inner _ e -> ESum1Inner ext (simplify' e)
EUnit _ e -> EUnit ext (simplify' e)
- -- EReplicate _ e -> EReplicate ext (simplify' e)
+ EReplicate1Inner _ a b -> EReplicate1Inner ext (simplify' a) (simplify' b)
EConst _ t v -> EConst ext t v
EIdx0 _ e -> EIdx0 ext (simplify' e)
EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b)
@@ -112,11 +114,13 @@ hasAdds = \case
EInl _ _ e -> hasAdds e
EInr _ _ e -> hasAdds e
ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b
+ EConstArr _ _ _ _ -> False
EBuild1 _ a b -> hasAdds a || hasAdds b
EBuild _ _ a b -> hasAdds a || hasAdds b
- EFold1 _ a b -> hasAdds a || hasAdds b
+ EFold1Inner _ a b -> hasAdds a || hasAdds b
+ ESum1Inner _ e -> hasAdds e
EUnit _ e -> hasAdds e
- -- EReplicate _ e -> hasAdds e
+ EReplicate1Inner _ a b -> hasAdds a || hasAdds b
EConst _ _ _ -> False
EIdx0 _ e -> hasAdds e
EIdx1 _ a b -> hasAdds a || hasAdds b