aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array/Nested/Trace/TH.hs
blob: 47e53cd0e9ce4fd99cbfbe2096965d4993d2fcab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
module Data.Array.Nested.Trace.TH where

import Control.Monad (zipWithM)
import Data.List (foldl', intersperse)
import Data.Maybe (isJust)
import Language.Haskell.TH hiding (cxt)

import Debug.Trace qualified as Debug

import Data.Array.Mixed.Types
import Data.Array.Nested


-- | Flatten a (possibly quantified) function type into its components.
--
-- The result is, in order:
--
-- * the type-variable binders of every @forall@ met along the result spine,
-- * the constraints attached to those @forall@s,
-- * the argument types, left to right,
-- * the final result type.
--
-- Binders and constraints are returned outermost first.  This matters when
-- the pieces are re-wrapped in a single 'ForallT': the variables keep their
-- original relative order, and a binder whose kind mentions an earlier
-- variable still has that variable in scope.
--
-- A type that is neither an arrow nor a @forall@ is returned unchanged as the
-- result type, with no binders, constraints or arguments.
splitFunTy :: Type -> ([TyVarBndr Specificity], Cxt, [Type], Type)
splitFunTy (ArrowT `AppT` t1 `AppT` t2) =
  let (vars, cx, args, ret) = splitFunTy t2
  in (vars, cx, t1 : args, ret)
splitFunTy (ForallT vs cx' t) =
  let (vars, cx, args, ret) = splitFunTy t
  -- Outer quantifier first, then whatever was found further in.
  in (vs ++ vars, cx' ++ cx, args, ret)
splitFunTy t = ([], [], [], t)

-- | Classification of a type whose values get printed by the generated
-- tracing wrappers.
data Relevant
  = RRanked Type Type
    -- ^ @Ranked@ applied to two type arguments (stored in application order).
  | RShaped Type Type
    -- ^ @Shaped@ applied to two type arguments (stored in application order).
  | RMixed Type Type
    -- ^ @Mixed@ applied to two type arguments (stored in application order).
  | RShowable Type
    -- ^ Some other recognised type; the whole type is stored.
  deriving (Show)

-- | Decide whether a type is one that the tracing wrappers should print.
--
-- * @Ranked@, @Shaped@ or @Mixed@ applied to exactly two arguments yields the
--   matching 'Relevant' constructor carrying those two arguments.
-- * @IShR@, @IIxR@, @ShS@, @IIxS@ or @SNat@ applied to exactly one argument
--   yields 'RShowable' carrying the whole type.
-- * Anything else yields 'Nothing'.
isRelevant :: Type -> Maybe Relevant
isRelevant candidate = case candidate of
  AppT (AppT (ConT con) sh) elt
    | con == ''Ranked -> Just (RRanked sh elt)
    | con == ''Shaped -> Just (RShaped sh elt)
    | con == ''Mixed -> Just (RMixed sh elt)
  AppT (ConT con) _
    | con `elem` showableHeads -> Just (RShowable candidate)
  _ -> Nothing
  where
    -- Single-argument type constructors that are printed as-is.
    showableHeads = [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat]

-- | Compute the signature of a tracing wrapper from the original function's
-- type.
--
-- The type is taken apart with 'splitFunTy' and each argument type, as well
-- as the result type, is classified with 'isRelevant'.  Returned are:
--
-- 1. the original type rebuilt under a single 'ForallT', with one extra
--    @Show@ constraint for every 'RRanked', 'RShaped' or 'RMixed' position
--    (the constraint is on the corresponding @Mixed@ type);
-- 2. one flag per argument telling whether that argument is relevant;
-- 3. a flag telling whether the result is relevant.
convertType :: Type -> Q (Type, [Bool], Bool)
convertType typ = return (wrappedTy, map isJust argRels, isJust retRel)
  where
    (binders, context, argTys, retTy) = splitFunTy typ
    argRels = map isRelevant argTys
    retRel = isRelevant retTy

    mixedOf sh elt = ConT ''Mixed `AppT` sh `AppT` elt

    -- Types that need a Show instance in order to print a value of the
    -- given class; 'RShowable' positions add no constraint.
    -- NOTE(review): 'Nothing is a data-constructor name wrapped in ConT;
    -- PromotedT is the constructor documented for promoted names -- confirm
    -- this is intentional.
    showTargets (RRanked n elt) =
      [mixedOf (ConT ''Replicate `AppT` n `AppT` ConT 'Nothing) elt]
    showTargets (RShaped sh elt) = [mixedOf (ConT ''MapJust `AppT` sh) elt]
    showTargets (RMixed sh elt) = [mixedOf sh elt]
    showTargets (RShowable _) = []

    -- Result first, then arguments.
    showCxt =
      map (AppT (ConT ''Show))
          (concatMap showTargets [rel | Just rel <- retRel : argRels])

    arrow from to = ArrowT `AppT` from `AppT` to

    wrappedTy =
      ForallT binders (context ++ showCxt) (foldr arrow retTy argTys)

-- | Generate a tracing wrapper for the function named @funname@.
--
-- The function's type is reified and passed through 'convertType'.  The
-- result is a pair of declarations: a type signature and a definition, both
-- for a top-level binding whose name is the unqualified base name of
-- @funname@.
--
-- The generated definition applies the original function to all of its
-- arguments and returns that result via 'Debug.trace'.  The trace message is
-- @oxtrace: \<base name\> [..]@, where the brackets hold the comma-separated
-- 'show' of every argument that 'convertType' flagged as relevant, followed
-- by the 'show' of the result if that was flagged too.
convertFun :: Name -> Q [Dec]
convertFun funname = do
  -- The wrapper is a top-level declaration that other code has to be able to
  -- mention by its ordinary name, so it needs a capturable name ('mkName').
  -- A 'newName' binder would be fresh and unreachable from hand-written code.
  let defname = mkName (nameBase funname)
  (convty, argarrs, retarr) <- reifyType funname >>= convertType
  -- Argument binders are local to the clause, so fresh names are right here.
  -- Relevant arguments are called t<i>, the others x<i>.
  names <- zipWithM (\b i -> newName ((if b then "t" else "x") ++ show i)) argarrs [1::Int ..]
  resname <- newName "res"
  -- Variables whose values are printed: flagged arguments, then the result.
  let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr])))
  let ex = LetE [ValD (VarP resname)
                      (NormalB (foldl' AppE (VarE funname) (map VarE names)))
                      []]
                (VarE 'Debug.trace
                  `AppE` (VarE 'concat `AppE` ListE
                            ([LitE (StringL ("oxtrace: " ++ nameBase funname ++ " ["))] ++
                             intersperse (LitE (StringL ", "))
                                         (map (\n -> VarE 'show `AppE` VarE n) tracenames) ++
                             [LitE (StringL "]")]))
                  `AppE` VarE resname)
  return
    [SigD defname convty
    ,FunD defname [Clause (map VarP names) (NormalB ex) []]]