summaryrefslogtreecommitdiff
path: root/src/AST/Count.hs
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-08-30 22:45:46 +0200
committerTom Smeding <tom@tomsmeding.com>2024-08-30 22:45:46 +0200
commit1f7ed2ee02222108684cfde8078e7a182f734a61 (patch)
tree976175ede4ec12a6e4a65d5e45e0b1ee8eeff5e6 /src/AST/Count.hs
parent172887fb577526de92b0653b5d3153114f8ce02a (diff)
WIP Build1
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r--src/AST/Count.hs106
1 files changed, 85 insertions, 21 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index f66b809..7e70a7d 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -1,10 +1,17 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
module AST.Count where
+import Data.Functor.Const
import GHC.Generics (Generic, Generically(..))
import AST
@@ -35,24 +42,81 @@ scaleMany :: Occ -> Occ
scaleMany (Occ l _) = Occ l Many
occCount :: Idx env a -> Expr x env t -> Occ
-occCount idx = \case
- EVar _ _ i | idx2int i == idx2int idx -> Occ One One
- | otherwise -> mempty
- ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body
- EPair _ a b -> occCount idx a <> occCount idx b
- EFst _ e -> occCount idx e
- ESnd _ e -> occCount idx e
- ENil _ -> mempty
- EInl _ _ e -> occCount idx e
- EInr _ _ e -> occCount idx e
- ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b)
- EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b)
- EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e)
- EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b
- EConst{} -> mempty
- EIdx1 _ a b -> occCount idx a <> occCount idx b
- EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es
- EOp _ _ e -> occCount idx e
- EWith a b -> occCount idx a <> occCount (IS idx) b
- EAccum a b e -> occCount idx a <> occCount idx b <> occCount idx e
- EError{} -> mempty
+occCount idx =
+ getConst . occCountGeneral
+ (\i o -> if idx2int i == idx2int idx then Const o else mempty)
+ (\(Const o) -> Const o)
+ (\_ (Const o) -> Const o)
+ (\(Const o1) (Const o2) -> Const (o1 <||> o2))
+ (\(Const o) -> Const (scaleMany o))
+
+
+data OccEnv env where
+ OccEnd :: OccEnv env -- not necessarily top!
+ OccPush :: OccEnv env -> Occ -> OccEnv (t : env)
+
+instance Semigroup (OccEnv env) where
+ OccEnd <> e = e
+ e <> OccEnd = e
+ OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o')
+
+instance Monoid (OccEnv env) where
+ mempty = OccEnd
+
+onehotOccEnv :: Idx env t -> Occ -> OccEnv env
+onehotOccEnv IZ v = OccPush OccEnd v
+onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty
+
+(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
+OccEnd <||>! e = e
+e <||>! OccEnd = e
+OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o')
+
+scaleManyOccEnv :: OccEnv env -> OccEnv env
+scaleManyOccEnv OccEnd = OccEnd
+scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o)
+
+occCountAll :: Expr x env t -> OccEnv env
+occCountAll = occCountGeneral onehotOccEnv unpush unpushN (<||>!) scaleManyOccEnv
+ where
+ unpush :: OccEnv (t : env) -> OccEnv env
+ unpush (OccPush o _) = o
+ unpush OccEnd = OccEnd
+
+ unpushN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env
+ unpushN _ OccEnd = OccEnd
+ unpushN SZ e = e
+ unpushN (SS n) (OccPush e _) = unpushN n e
+
+occCountGeneral :: forall r env t x.
+ (forall env'. Monoid (r env'))
+ => (forall env' a. Idx env' a -> Occ -> r env') -- ^ one-hot
+ -> (forall env' a. r (a : env') -> r env') -- ^ unpush
+ -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env') -- ^ unpushN
+ -> (forall env'. r env' -> r env' -> r env') -- ^ alternation
+ -> (forall env'. r env' -> r env') -- ^ scale-many
+ -> Expr x env t -> r env
+occCountGeneral onehot unpush unpushN alter many = go
+ where
+ go :: Monoid (r env') => Expr x env' t' -> r env'
+ go = \case
+ EVar _ _ i -> onehot i (Occ One One)
+ ELet _ rhs body -> go rhs <> unpush (go body)
+ EPair _ a b -> go a <> go b
+ EFst _ e -> go e
+ ESnd _ e -> go e
+ ENil _ -> mempty
+ EInl _ _ e -> go e
+ EInr _ _ e -> go e
+ ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b))
+ EBuild1 _ a b -> go a <> many (unpush (go b))
+ EBuild _ es e -> foldMap go es <> many (unpushN (vecLength es) (go e))
+ EFold1 _ a b -> many (unpush (unpush (go a))) <> go b
+ EConst{} -> mempty
+ EIdx0 _ e -> go e
+ EIdx1 _ a b -> go a <> go b
+ EIdx _ e es -> go e <> foldMap go es
+ EOp _ _ e -> go e
+ EWith a b -> go a <> unpush (go b)
+ EAccum a b e -> go a <> go b <> go e
+ EError{} -> mempty