diff options
Diffstat (limited to 'src/Data/Array/Nested/Trace/TH.hs')
-rw-r--r-- | src/Data/Array/Nested/Trace/TH.hs | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Trace/TH.hs b/src/Data/Array/Nested/Trace/TH.hs new file mode 100644 index 0000000..47e53cd --- /dev/null +++ b/src/Data/Array/Nested/Trace/TH.hs @@ -0,0 +1,82 @@ +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Nested.Trace.TH where + +import Control.Monad (zipWithM) +import Data.List (foldl', intersperse) +import Data.Maybe (isJust) +import Language.Haskell.TH hiding (cxt) + +import Debug.Trace qualified as Debug + +import Data.Array.Mixed.Types +import Data.Array.Nested + + +splitFunTy :: Type -> ([TyVarBndr Specificity], Cxt, [Type], Type) +splitFunTy = \case + ArrowT `AppT` t1 `AppT` t2 -> + let (vars, cx, args, ret) = splitFunTy t2 + in (vars, cx, t1 : args, ret) + ForallT vs cx' t -> + let (vars, cx, args, ret) = splitFunTy t + in (vars ++ vs, cx ++ cx', args, ret) + t -> ([], [], [], t) + +data Relevant = RRanked Type Type + | RShaped Type Type + | RMixed Type Type + | RShowable Type + deriving (Show) + +-- | If so, returns the element type +isRelevant :: Type -> Maybe Relevant +isRelevant (ConT name `AppT` sht `AppT` ty) + | name == ''Ranked = Just (RRanked sht ty) + | name == ''Shaped = Just (RShaped sht ty) + | name == ''Mixed = Just (RMixed sht ty) +isRelevant ty@(ConT name `AppT` _) + | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat] = + Just (RShowable ty) +isRelevant _ = Nothing + +convertType :: Type -> Q (Type, [Bool], Bool) +convertType typ = + let (tybndrs, cxt, args, ret) = splitFunTy typ + argrels = map isRelevant args + retrel = isRelevant ret + + showhead (RRanked n ty) = [ConT ''Mixed `AppT` (ConT ''Replicate `AppT` n `AppT` ConT 'Nothing) `AppT` ty] + showhead (RShaped sh ty) = [ConT ''Mixed `AppT` (ConT ''MapJust `AppT` sh) `AppT` ty] + showhead (RMixed sh ty) = [ConT ''Mixed `AppT` sh `AppT` ty] + showhead (RShowable _) = [] + in return + (ForallT tybndrs + (cxt ++ [ConT ''Show `AppT` hd + | Just rel <- retrel : argrels + , hd <- showhead rel]) + (foldr (\a b -> ArrowT `AppT` a `AppT` b) ret args) + ,map isJust argrels + ,isJust retrel) + +convertFun :: Name -> Q [Dec] +convertFun funname = do + defname <- newName (nameBase funname) + (convty, argarrs, retarr) <- reifyType funname >>= convertType + names <- zipWithM (\b i -> newName ((if b then "t" else "x") ++ show i)) argarrs [1::Int ..] + resname <- newName "res" + let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr]))) + let ex = LetE [ValD (VarP resname) + (NormalB (foldl' AppE (VarE funname) (map VarE names))) + []] + (VarE 'Debug.trace + `AppE` (VarE 'concat `AppE` ListE + ([LitE (StringL ("oxtrace: " ++ nameBase funname ++ " ["))] ++ + intersperse (LitE (StringL ", ")) + (map (\n -> VarE 'show `AppE` VarE n) tracenames) ++ + [LitE (StringL "]")])) + `AppE` VarE resname) + return + [SigD defname convty + ,FunD defname [Clause (map VarP names) (NormalB ex) []]] |