summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-03-07 17:20:56 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-03-07 17:20:56 +0100
commitfe2a29ee350c0078c845f4457d85941d41583142 (patch)
tree626562f389caeb08137ef9e67155375d0662bbe7
parentd588b0a245bb162ea60165eb2e42ef44e9b540bc (diff)
idana: Track array shapes
-rw-r--r--src/Analysis/Identity.hs87
-rw-r--r--src/Data.hs12
2 files changed, 70 insertions, 29 deletions
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 285cfb8..7e10481 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -6,6 +6,9 @@ module Analysis.Identity (
identityAnalysis,
) where
+import Data.Foldable (toList)
+import Data.List (intercalate)
+
import AST
import AST.Pretty (PrettyX(..))
import CHAD.Types (d1, d2)
@@ -23,7 +26,7 @@ data ValId t where
VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case
VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)
VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value
- VIArr :: Int -> ValId (TArr n t)
+ VIArr :: Int -> Vec n Int -> ValId (TArr n t)
VIScal :: Int -> ValId (TScal t)
VIAccum :: Int -> ValId (TAccum t)
@@ -44,7 +47,7 @@ instance Eq (ValId t) where
VIMaybe{} == _ = False
VIMaybe' a == VIMaybe' a' = a == a'
VIMaybe'{} == _ = False
- VIArr i == VIArr i' = i == i'
+ VIArr i is == VIArr i' is' = i == i' && is == is'
VIArr{} == _ = False
VIScal i == VIScal i' = i == i'
VIScal{} == _ = False
@@ -63,7 +66,7 @@ instance PrettyX ValId where
VIMaybe Nothing -> "N"
VIMaybe (Just a) -> 'J' : prettyX a
VIMaybe' a -> 'M' : prettyX a
- VIArr i -> 'A' : show i
+ VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]"
VIScal i -> show i
VIAccum i -> 'C' : show i
VIThing _ i -> '{' : show i ++ "}"
@@ -174,50 +177,55 @@ idana env expr = case expr of
pure (res, EMaybe res e1' e2' e3')
EConstArr _ dim t arr -> do
- x1 <- VIArr <$> genId
+ x1 <- VIArr <$> genId <*> vecReplicateA dim genId
pure (x1, EConstArr x1 dim t arr)
EBuild _ dim e1 e2 -> do
- (_, e1') <- idana env e1
+ (shids, e1') <- idana env e1
x1 <- genIds (tTup (sreplicate dim tIx))
(_, e2') <- idana (x1 `SCons` env) e2
- res <- VIArr <$> genId
+ res <- VIArr <$> genId <*> shidsToVec dim shids
pure (res, EBuild res dim e1' e2')
EFold1Inner _ e1 e2 e3 -> do
let t1 = typeOf e1
+ STArr dim _ = typeOf e3
x1 <- genIds t1
x2 <- genIds t1
(_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1
(_, e2') <- idana env e2
- (_, e3') <- idana env e3
- res <- VIArr <$> genId
+ (v3, e3') <- idana env e3
+ res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v3)
pure (res, EFold1Inner res e1' e2' e3')
ESum1Inner _ e1 -> do
- (_, e1') <- idana env e1
- res <- VIArr <$> genId
+ let STArr dim _ = typeOf e1
+ (v1, e1') <- idana env e1
+ res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1)
pure (res, ESum1Inner res e1')
EUnit _ e1 -> do
(_, e1') <- idana env e1
- res <- VIArr <$> genId
+ res <- VIArr <$> genId <*> pure VNil
pure (res, EUnit res e1')
EReplicate1Inner _ e1 e2 -> do
- (_, e1') <- idana env e1
- (_, e2') <- idana env e2
- res <- VIArr <$> genId
+ let STArr dim _ = typeOf e2
+ (v1, e1') <- idana env e1
+ (v2, e2') <- idana env e2
+ res <- VIArr <$> genId <*> ((valScalId v1 :<) <$> valArrShape dim v2)
pure (res, EReplicate1Inner res e1' e2')
EMaximum1Inner _ e1 -> do
- (_, e1') <- idana env e1
- res <- VIArr <$> genId
+ let STArr dim _ = typeOf e1
+ (v1, e1') <- idana env e1
+ res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1)
pure (res, EMaximum1Inner res e1')
EMinimum1Inner _ e1 -> do
- (_, e1') <- idana env e1
- res <- VIArr <$> genId
+ let STArr dim _ = typeOf e1
+ (v1, e1') <- idana env e1
+ res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1)
pure (res, EMinimum1Inner res e1')
EConst _ t val -> do
@@ -230,9 +238,10 @@ idana env expr = case expr of
pure (res, EIdx0 res e1')
EIdx1 _ e1 e2 -> do
- (_, e1') <- idana env e1
+ let STArr dim _ = typeOf e1
+ (v1, e1') <- idana env e1
(_, e2') <- idana env e2
- res <- genIds (typeOf expr)
+ res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1)
pure (res, EIdx1 res e1' e2')
EIdx _ e1 e2 -> do
@@ -242,8 +251,9 @@ idana env expr = case expr of
pure (res, EIdx res e1' e2')
EShape _ e1 -> do
- (_, e1') <- idana env e1
- res <- genIds (typeOf expr)
+ let STArr dim _ = typeOf e1
+ (v1, e1') <- idana env e1
+ res <- vecToShids dim <$> valArrShape dim v1
pure (res, EShape res e1')
EOp _ (op :: SOp a t) e1 -> do
@@ -324,22 +334,41 @@ unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b
unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a
unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b
unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b
-unify (VIArr i) (VIArr j) | i == j = pure $ VIArr i
- | otherwise = VIArr <$> genId
-unify (VIScal i) (VIScal j) | i == j = pure $ VIScal i
- | otherwise = VIScal <$> genId
-unify (VIAccum i) (VIAccum j) | i == j = pure $ VIAccum i
- | otherwise = VIAccum <$> genId
+unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js
+unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j
+unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j
unify (VIThing t i) (VIThing _ j) | i == j = pure $ VIThing t i
| otherwise = genIds t
unify (VIThing t _) _ = genIds t
unify _ (VIThing t _) = genIds t
+unifyID :: Int -> Int -> IdGen Int
+unifyID i j | i == j = pure i
+ | otherwise = genId
+
genIds :: STy t -> IdGen (ValId t)
genIds STNil = pure VINil
genIds (STPair a b) = VIPair <$> genIds a <*> genIds b
genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b
genIds (STMaybe t) = VIMaybe' <$> genIds t
-genIds STArr{} = VIArr <$> genId
+genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId
genIds STScal{} = VIScal <$> genId
genIds STAccum{} = VIAccum <$> genId
+
+shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int)
+shidsToVec SZ _ = pure VNil
+shidsToVec (SS n) (VIPair is (VIScal i)) = (i :<) <$> shidsToVec n is
+shidsToVec (SS n) (VIPair is (VIThing _ i)) = (i :<) <$> shidsToVec n is
+shidsToVec n VIThing{} = vecReplicateA n genId
+
+vecToShids :: SNat n -> Vec n Int -> ValId (Tup (Replicate n TIx))
+vecToShids SZ VNil = VINil
+vecToShids (SS n) (i :< is) = VIPair (vecToShids n is) (VIScal i)
+
+valScalId :: ValId (TScal t) -> Int
+valScalId (VIScal i) = i
+valScalId (VIThing _ i) = i
+
+valArrShape :: SNat n -> ValId (TArr n t) -> IdGen (Vec n Int)
+valArrShape _ (VIArr _ v) = pure v
+valArrShape n _ = vecReplicateA n genId
diff --git a/src/Data.hs b/src/Data.hs
index 155eeb3..d83c206 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -115,6 +115,7 @@ data Vec n t where
VNil :: Vec Z t
(:<) :: t -> Vec n t -> Vec (S n) t
deriving instance Show t => Show (Vec n t)
+deriving instance Eq t => Eq (Vec n t)
deriving instance Functor (Vec n)
deriving instance Foldable (Vec n)
deriving instance Traversable (Vec n)
@@ -130,6 +131,17 @@ vecGenerate = \n f -> go n f SZ
go SZ _ _ = VNil
go (SS n) f i = f i :< go n f (SS i)
+vecReplicateA :: Applicative f => SNat n -> f a -> f (Vec n a)
+vecReplicateA SZ _ = pure VNil
+vecReplicateA (SS n) gen = (:<) <$> gen <*> vecReplicateA n gen
+
+vecZipWithA :: Applicative f => (a -> b -> f c) -> Vec n a -> Vec n b -> f (Vec n c)
+vecZipWithA _ VNil VNil = pure VNil
+vecZipWithA f (x :< xs) (y :< ys) = (:<) <$> f x y <*> vecZipWithA f xs ys
+
+vecTail :: Vec (S n) a -> Vec n a
+vecTail (_ :< xs) = xs
+
unsafeCoerceRefl :: a :~: b
unsafeCoerceRefl = unsafeCoerce Refl