diff options
-rw-r--r-- | ox-arrays.cabal | 2 | ||||
-rw-r--r-- | src/Data/Array/Nested/Trace.hs | 65 | ||||
-rw-r--r-- | src/Data/Array/Nested/Trace/TH.hs | 82 |
3 files changed, 149 insertions, 0 deletions
diff --git a/ox-arrays.cabal b/ox-arrays.cabal index 99df88c..f114709 100644 --- a/ox-arrays.cabal +++ b/ox-arrays.cabal @@ -19,6 +19,8 @@ library Data.Array.Mixed.Types Data.Array.Mixed.XArray Data.Array.Nested + Data.Array.Nested.Trace + Data.Array.Nested.Trace.TH Data.Array.Nested.Internal.Convert Data.Array.Nested.Internal.Mixed Data.Array.Nested.Internal.Lemmas diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs new file mode 100644 index 0000000..eadfeeb --- /dev/null +++ b/src/Data/Array/Nested/Trace.hs @@ -0,0 +1,65 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExplicitNamespaces #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TemplateHaskell #-} +{-| +This module is API-compatible with "Data.Array.Nested", except that inputs and +outputs of the methods are traced using 'Debug.Trace.trace'. Thus the methods +also have additional 'Show' constraints. + +>>> let res = rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7)) +>>> length (show res) `seq` () +oxtrace: riota [Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5]))))] +oxtrace: rreshape [[2,3], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5]))))] +oxtrace: rtranspose [Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,3,1,4,2,5]))))] +oxtrace: rscalar [Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7]))))] +oxtrace: rreplicate [[6], Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7])))), Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7]))))] +oxtrace: rreshape [[3,2], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [7,7,7,7,7,7]))))] +>>> res +Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,21,7,28,14,35])))) +-} +module Data.Array.Nested.Trace ( + -- * Traced variants + module Data.Array.Nested.Trace, + + -- * Re-exports from the plain "Data.Array.Nested" module + Ranked(Ranked), + ListR(ZR, (:::)), + IxR(..), IIxR, + ShR(..), IShR, + + Shaped(Shaped), + ListS(ZS, (::$)), + IxS(..), IIxS, + ShS(..), KnownShS(..), + + Mixed, + IxX(..), IIxX, + KnownShX(..), StaticShX(..), + + Elt, + PrimElt, + Primitive(..), + KnownElt, + + type (++), + Storable, + SNat, pattern SNat, + pattern SZ, pattern SS, + Perm(..), + IsPermutation, + KnownPerm(..), + NumElt, FloatElt, +) where + +import Prelude hiding (mappend) + +import Data.Array.Nested +import Data.Array.Nested.Trace.TH + + +$(concat <$> mapM convertFun + ['rshape , 'rrank, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rtranspose, 'rappend, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromListOuter, 'rfromList1, 'rfromList1Prim, 'rtoListOuter, 'rtoList1, 'rslice, 'rrev1, 'rreshape, 'riota, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'sshape, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromListOuter, 'sfromList1, 'sfromList1Prim, 'stoListOuter, 'stoList1, 'sslice, 'srev1, 'sreshape, 'siota, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoRanked, 'mshape, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'mtranspose, 'mappend, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromListOuter, 'mfromList1, 'mfromList1Prim, 'mtoListOuter, 'mtoList1, 'mslice, 'mrev1, 'mreshape, 'miota, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mtoRanked, 'mcastToShaped]) diff --git a/src/Data/Array/Nested/Trace/TH.hs b/src/Data/Array/Nested/Trace/TH.hs new file mode 100644 index 0000000..47e53cd --- /dev/null +++ b/src/Data/Array/Nested/Trace/TH.hs @@ -0,0 +1,82 @@ +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TemplateHaskellQuotes #-} +module Data.Array.Nested.Trace.TH where + +import Control.Monad (zipWithM) +import Data.List (foldl', intersperse) +import Data.Maybe (isJust) +import Language.Haskell.TH hiding (cxt) + +import Debug.Trace qualified as Debug + +import Data.Array.Mixed.Types +import Data.Array.Nested + + +splitFunTy :: Type -> ([TyVarBndr Specificity], Cxt, [Type], Type) +splitFunTy = \case + ArrowT `AppT` t1 `AppT` t2 -> + let (vars, cx, args, ret) = splitFunTy t2 + in (vars, cx, t1 : args, ret) + ForallT vs cx' t -> + let (vars, cx, args, ret) = splitFunTy t + in (vars ++ vs, cx ++ cx', args, ret) + t -> ([], [], [], t) + +data Relevant = RRanked Type Type + | RShaped Type Type + | RMixed Type Type + | RShowable Type + deriving (Show) + +-- | If so, returns the element type +isRelevant :: Type -> Maybe Relevant +isRelevant (ConT name `AppT` sht `AppT` ty) + | name == ''Ranked = Just (RRanked sht ty) + | name == ''Shaped = Just (RShaped sht ty) + | name == ''Mixed = Just (RMixed sht ty) +isRelevant ty@(ConT name `AppT` _) + | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat] = + Just (RShowable ty) +isRelevant _ = Nothing + +convertType :: Type -> Q (Type, [Bool], Bool) +convertType typ = + let (tybndrs, cxt, args, ret) = splitFunTy typ + argrels = map isRelevant args + retrel = isRelevant ret + + showhead (RRanked n ty) = [ConT ''Mixed `AppT` (ConT ''Replicate `AppT` n `AppT` ConT 'Nothing) `AppT` ty] + showhead (RShaped sh ty) = [ConT ''Mixed `AppT` (ConT ''MapJust `AppT` sh) `AppT` ty] + showhead (RMixed sh ty) = [ConT ''Mixed `AppT` sh `AppT` ty] + showhead (RShowable _) = [] + in return + (ForallT tybndrs + (cxt ++ [ConT ''Show `AppT` hd + | Just rel <- retrel : argrels + , hd <- showhead rel]) + (foldr (\a b -> ArrowT `AppT` a `AppT` b) ret args) + ,map isJust argrels + ,isJust retrel) + +convertFun :: Name -> Q [Dec] +convertFun funname = do + defname <- newName (nameBase funname) + (convty, argarrs, retarr) <- reifyType funname >>= convertType + names <- zipWithM (\b i -> newName ((if b then "t" else "x") ++ show i)) argarrs [1::Int ..] + resname <- newName "res" + let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr]))) + let ex = LetE [ValD (VarP resname) + (NormalB (foldl' AppE (VarE funname) (map VarE names))) + []] + (VarE 'Debug.trace + `AppE` (VarE 'concat `AppE` ListE + ([LitE (StringL ("oxtrace: " ++ nameBase funname ++ " ["))] ++ + intersperse (LitE (StringL ", ")) + (map (\n -> VarE 'show `AppE` VarE n) tracenames) ++ + [LitE (StringL "]")])) + `AppE` VarE resname) + return + [SigD defname convty + ,FunD defname [Clause (map VarP names) (NormalB ex) []]] |