diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-07 17:20:56 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-07 17:20:56 +0100 |
commit | fe2a29ee350c0078c845f4457d85941d41583142 (patch) | |
tree | 626562f389caeb08137ef9e67155375d0662bbe7 | |
parent | d588b0a245bb162ea60165eb2e42ef44e9b540bc (diff) |
idana: Track array shapes
-rw-r--r-- | src/Analysis/Identity.hs | 87 | ||||
-rw-r--r-- | src/Data.hs | 12 |
2 files changed, 70 insertions, 29 deletions
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 285cfb8..7e10481 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -6,6 +6,9 @@ module Analysis.Identity ( identityAnalysis, ) where +import Data.Foldable (toList) +import Data.List (intercalate) + import AST import AST.Pretty (PrettyX(..)) import CHAD.Types (d1, d2) @@ -23,7 +26,7 @@ data ValId t where VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a) VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value - VIArr :: Int -> ValId (TArr n t) + VIArr :: Int -> Vec n Int -> ValId (TArr n t) VIScal :: Int -> ValId (TScal t) VIAccum :: Int -> ValId (TAccum t) @@ -44,7 +47,7 @@ instance Eq (ValId t) where VIMaybe{} == _ = False VIMaybe' a == VIMaybe' a' = a == a' VIMaybe'{} == _ = False - VIArr i == VIArr i' = i == i' + VIArr i is == VIArr i' is' = i == i' && is == is' VIArr{} == _ = False VIScal i == VIScal i' = i == i' VIScal{} == _ = False @@ -63,7 +66,7 @@ instance PrettyX ValId where VIMaybe Nothing -> "N" VIMaybe (Just a) -> 'J' : prettyX a VIMaybe' a -> 'M' : prettyX a - VIArr i -> 'A' : show i + VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]" VIScal i -> show i VIAccum i -> 'C' : show i VIThing _ i -> '{' : show i ++ "}" @@ -174,50 +177,55 @@ idana env expr = case expr of pure (res, EMaybe res e1' e2' e3') EConstArr _ dim t arr -> do - x1 <- VIArr <$> genId + x1 <- VIArr <$> genId <*> vecReplicateA dim genId pure (x1, EConstArr x1 dim t arr) EBuild _ dim e1 e2 -> do - (_, e1') <- idana env e1 + (shids, e1') <- idana env e1 x1 <- genIds (tTup (sreplicate dim tIx)) (_, e2') <- idana (x1 `SCons` env) e2 - res <- VIArr <$> genId + res <- VIArr <$> genId <*> shidsToVec dim shids pure (res, EBuild res dim e1' e2') EFold1Inner _ e1 e2 e3 -> do let t1 = typeOf e1 + STArr dim _ = typeOf e3 x1 <- genIds t1 x2 <- genIds t1 (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 (_, e2') <- idana env e2 - (_, e3') <- idana env e3 - res <- VIArr <$> genId + (v3, e3') <- idana env e3 + res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v3) pure (res, EFold1Inner res e1' e2' e3') ESum1Inner _ e1 -> do - (_, e1') <- idana env e1 - res <- VIArr <$> genId + let STArr dim _ = typeOf e1 + (v1, e1') <- idana env e1 + res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1) pure (res, ESum1Inner res e1') EUnit _ e1 -> do (_, e1') <- idana env e1 - res <- VIArr <$> genId + res <- VIArr <$> genId <*> pure VNil pure (res, EUnit res e1') EReplicate1Inner _ e1 e2 -> do - (_, e1') <- idana env e1 - (_, e2') <- idana env e2 - res <- VIArr <$> genId + let STArr dim _ = typeOf e2 + (v1, e1') <- idana env e1 + (v2, e2') <- idana env e2 + res <- VIArr <$> genId <*> ((valScalId v1 :<) <$> valArrShape dim v2) pure (res, EReplicate1Inner res e1' e2') EMaximum1Inner _ e1 -> do - (_, e1') <- idana env e1 - res <- VIArr <$> genId + let STArr dim _ = typeOf e1 + (v1, e1') <- idana env e1 + res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1) pure (res, EMaximum1Inner res e1') EMinimum1Inner _ e1 -> do - (_, e1') <- idana env e1 - res <- VIArr <$> genId + let STArr dim _ = typeOf e1 + (v1, e1') <- idana env e1 + res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1) pure (res, EMinimum1Inner res e1') EConst _ t val -> do @@ -230,9 +238,10 @@ idana env expr = case expr of pure (res, EIdx0 res e1') EIdx1 _ e1 e2 -> do - (_, e1') <- idana env e1 + let STArr dim _ = typeOf e1 + (v1, e1') <- idana env e1 (_, e2') <- idana env e2 - res <- genIds (typeOf expr) + res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1) pure (res, EIdx1 res e1' e2') EIdx _ e1 e2 -> do @@ -242,8 +251,9 @@ idana env expr = case expr of pure (res, EIdx res e1' e2') EShape _ e1 -> do - (_, e1') <- idana env e1 - res <- genIds (typeOf expr) + let STArr dim _ = typeOf e1 + (v1, e1') <- idana env e1 + res <- vecToShids dim <$> valArrShape dim v1 pure (res, EShape res e1') EOp _ (op :: SOp a t) e1 -> do @@ -324,22 +334,41 @@ unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b -unify (VIArr i) (VIArr j) | i == j = pure $ VIArr i - | otherwise = VIArr <$> genId -unify (VIScal i) (VIScal j) | i == j = pure $ VIScal i - | otherwise = VIScal <$> genId -unify (VIAccum i) (VIAccum j) | i == j = pure $ VIAccum i - | otherwise = VIAccum <$> genId +unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js +unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j +unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j unify (VIThing t i) (VIThing _ j) | i == j = pure $ VIThing t i | otherwise = genIds t unify (VIThing t _) _ = genIds t unify _ (VIThing t _) = genIds t +unifyID :: Int -> Int -> IdGen Int +unifyID i j | i == j = pure i + | otherwise = genId + genIds :: STy t -> IdGen (ValId t) genIds STNil = pure VINil genIds (STPair a b) = VIPair <$> genIds a <*> genIds b genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b genIds (STMaybe t) = VIMaybe' <$> genIds t -genIds STArr{} = VIArr <$> genId +genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId genIds STScal{} = VIScal <$> genId genIds STAccum{} = VIAccum <$> genId + +shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int) +shidsToVec SZ _ = pure VNil +shidsToVec (SS n) (VIPair is (VIScal i)) = (i :<) <$> shidsToVec n is +shidsToVec (SS n) (VIPair is (VIThing _ i)) = (i :<) <$> shidsToVec n is +shidsToVec n VIThing{} = vecReplicateA n genId + +vecToShids :: SNat n -> Vec n Int -> ValId (Tup (Replicate n TIx)) +vecToShids SZ VNil = VINil +vecToShids (SS n) (i :< is) = VIPair (vecToShids n is) (VIScal i) + +valScalId :: ValId (TScal t) -> Int +valScalId (VIScal i) = i +valScalId (VIThing _ i) = i + +valArrShape :: SNat n -> ValId (TArr n t) -> IdGen (Vec n Int) +valArrShape _ (VIArr _ v) = pure v +valArrShape n _ = vecReplicateA n genId diff --git a/src/Data.hs b/src/Data.hs index 155eeb3..d83c206 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -115,6 +115,7 @@ data Vec n t where VNil :: Vec Z t (:<) :: t -> Vec n t -> Vec (S n) t deriving instance Show t => Show (Vec n t) +deriving instance Eq t => Eq (Vec n t) deriving instance Functor (Vec n) deriving instance Foldable (Vec n) deriving instance Traversable (Vec n) @@ -130,6 +131,17 @@ vecGenerate = \n f -> go n f SZ go SZ _ _ = VNil go (SS n) f i = f i :< go n f (SS i) +vecReplicateA :: Applicative f => SNat n -> f a -> f (Vec n a) +vecReplicateA SZ _ = pure VNil +vecReplicateA (SS n) gen = (:<) <$> gen <*> vecReplicateA n gen + +vecZipWithA :: Applicative f => (a -> b -> f c) -> Vec n a -> Vec n b -> f (Vec n c) +vecZipWithA _ VNil VNil = pure VNil +vecZipWithA f (x :< xs) (y :< ys) = (:<) <$> f x y <*> vecZipWithA f xs ys + +vecTail :: Vec (S n) a -> Vec n a +vecTail (_ :< xs) = xs + unsafeCoerceRefl :: a :~: b unsafeCoerceRefl = unsafeCoerce Refl |