diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-08-30 22:45:46 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-08-30 22:45:46 +0200 |
commit | 1f7ed2ee02222108684cfde8078e7a182f734a61 (patch) | |
tree | 976175ede4ec12a6e4a65d5e45e0b1ee8eeff5e6 /src/AST/Count.hs | |
parent | 172887fb577526de92b0653b5d3153114f8ce02a (diff) |
WIP Build1
Diffstat (limited to 'src/AST/Count.hs')
-rw-r--r-- | src/AST/Count.hs | 106 |
1 files changed, 85 insertions, 21 deletions
diff --git a/src/AST/Count.hs b/src/AST/Count.hs index f66b809..7e70a7d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -1,10 +1,17 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} module AST.Count where +import Data.Functor.Const import GHC.Generics (Generic, Generically(..)) import AST @@ -35,24 +42,81 @@ scaleMany :: Occ -> Occ scaleMany (Occ l _) = Occ l Many occCount :: Idx env a -> Expr x env t -> Occ -occCount idx = \case - EVar _ _ i | idx2int i == idx2int idx -> Occ One One - | otherwise -> mempty - ELet _ rhs body -> occCount idx rhs <> occCount (IS idx) body - EPair _ a b -> occCount idx a <> occCount idx b - EFst _ e -> occCount idx e - ESnd _ e -> occCount idx e - ENil _ -> mempty - EInl _ _ e -> occCount idx e - EInr _ _ e -> occCount idx e - ECase _ e a b -> occCount idx e <> (occCount (IS idx) a <||> occCount (IS idx) b) - EBuild1 _ a b -> occCount idx a <> scaleMany (occCount (IS idx) b) - EBuild _ es e -> foldMap (occCount idx) es <> scaleMany (occCount (wsinkN (vecLength es) @> idx) e) - EFold1 _ a b -> scaleMany (occCount (IS (IS idx)) a) <> occCount idx b - EConst{} -> mempty - EIdx1 _ a b -> occCount idx a <> occCount idx b - EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es - EOp _ _ e -> occCount idx e - EWith a b -> occCount idx a <> occCount (IS idx) b - EAccum a b e -> occCount idx a <> occCount idx b <> occCount idx e - EError{} -> mempty +occCount idx = + getConst . occCountGeneral + (\i o -> if idx2int i == idx2int idx then Const o else mempty) + (\(Const o) -> Const o) + (\_ (Const o) -> Const o) + (\(Const o1) (Const o2) -> Const (o1 <||> o2)) + (\(Const o) -> Const (scaleMany o)) + + +data OccEnv env where + OccEnd :: OccEnv env -- not necessarily top! + OccPush :: OccEnv env -> Occ -> OccEnv (t : env) + +instance Semigroup (OccEnv env) where + OccEnd <> e = e + e <> OccEnd = e + OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o') + +instance Monoid (OccEnv env) where + mempty = OccEnd + +onehotOccEnv :: Idx env t -> Occ -> OccEnv env +onehotOccEnv IZ v = OccPush OccEnd v +onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty + +(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env +OccEnd <||>! e = e +e <||>! OccEnd = e +OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o') + +scaleManyOccEnv :: OccEnv env -> OccEnv env +scaleManyOccEnv OccEnd = OccEnd +scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o) + +occCountAll :: Expr x env t -> OccEnv env +occCountAll = occCountGeneral onehotOccEnv unpush unpushN (<||>!) scaleManyOccEnv + where + unpush :: OccEnv (t : env) -> OccEnv env + unpush (OccPush o _) = o + unpush OccEnd = OccEnd + + unpushN :: SNat n -> OccEnv (ConsN n TIx env) -> OccEnv env + unpushN _ OccEnd = OccEnd + unpushN SZ e = e + unpushN (SS n) (OccPush e _) = unpushN n e + +occCountGeneral :: forall r env t x. + (forall env'. Monoid (r env')) + => (forall env' a. Idx env' a -> Occ -> r env') -- ^ one-hot + -> (forall env' a. r (a : env') -> r env') -- ^ unpush + -> (forall env' n. SNat n -> r (ConsN n TIx env') -> r env') -- ^ unpushN + -> (forall env'. r env' -> r env' -> r env') -- ^ alternation + -> (forall env'. r env' -> r env') -- ^ scale-many + -> Expr x env t -> r env +occCountGeneral onehot unpush unpushN alter many = go + where + go :: Monoid (r env') => Expr x env' t' -> r env' + go = \case + EVar _ _ i -> onehot i (Occ One One) + ELet _ rhs body -> go rhs <> unpush (go body) + EPair _ a b -> go a <> go b + EFst _ e -> go e + ESnd _ e -> go e + ENil _ -> mempty + EInl _ _ e -> go e + EInr _ _ e -> go e + ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b)) + EBuild1 _ a b -> go a <> many (unpush (go b)) + EBuild _ es e -> foldMap go es <> many (unpushN (vecLength es) (go e)) + EFold1 _ a b -> many (unpush (unpush (go a))) <> go b + EConst{} -> mempty + EIdx0 _ e -> go e + EIdx1 _ a b -> go a <> go b + EIdx _ e es -> go e <> foldMap go es + EOp _ _ e -> go e + EWith a b -> go a <> unpush (go b) + EAccum a b e -> go a <> go b <> go e + EError{} -> mempty |