aboutsummaryrefslogtreecommitdiff
path: root/src/Data/Array
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-06-09 11:25:29 +0200
committerTom Smeding <tom@tomsmeding.com>2024-06-09 11:42:13 +0200
commitc0ccb34d23e621a469460133fd9cf6e2223ed07a (patch)
tree28c836e110ed271fbe940347e21bdad22f30c783 /src/Data/Array
parent69982cb812156f9ed1ae136ec928a505495505db (diff)
Traced Nested module
Diffstat (limited to 'src/Data/Array')
-rw-r--r--src/Data/Array/Nested/Trace.hs65
-rw-r--r--src/Data/Array/Nested/Trace/TH.hs82
2 files changed, 147 insertions, 0 deletions
diff --git a/src/Data/Array/Nested/Trace.hs b/src/Data/Array/Nested/Trace.hs
new file mode 100644
index 0000000..eadfeeb
--- /dev/null
+++ b/src/Data/Array/Nested/Trace.hs
@@ -0,0 +1,65 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ExplicitNamespaces #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-|
+This module is API-compatible with "Data.Array.Nested", except that inputs and
+outputs of the methods are traced using 'Debug.Trace.trace'. Thus the methods
+also have additional 'Show' constraints.
+
+>>> let res = rtranspose [1, 0] (rreshape (2 :$: 3 :$: ZSR) (riota @Int 6)) * rreshape (3 :$: 2 :$: ZSR) (rreplicate (6 :$: ZSR) (rscalar @Int 7))
+>>> length (show res) `seq` ()
+oxtrace: riota [Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5]))))]
+oxtrace: rreshape [[2,3], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5]))))]
+oxtrace: rtranspose [Ranked (M_Int (M_Primitive [2,3] (XArray (fromList [2,3] [0,1,2,3,4,5])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,3,1,4,2,5]))))]
+oxtrace: rscalar [Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7]))))]
+oxtrace: rreplicate [[6], Ranked (M_Int (M_Primitive [] (XArray (fromList [] [7])))), Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7]))))]
+oxtrace: rreshape [[3,2], Ranked (M_Int (M_Primitive [6] (XArray (fromList [6] [7,7,7,7,7,7])))), Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [7,7,7,7,7,7]))))]
+>>> res
+Ranked (M_Int (M_Primitive [3,2] (XArray (fromList [3,2] [0,21,7,28,14,35]))))
+-}
+module Data.Array.Nested.Trace (
+ -- * Traced variants
+ module Data.Array.Nested.Trace,
+
+ -- * Re-exports from the plain "Data.Array.Nested" module
+ Ranked(Ranked),
+ ListR(ZR, (:::)),
+ IxR(..), IIxR,
+ ShR(..), IShR,
+
+ Shaped(Shaped),
+ ListS(ZS, (::$)),
+ IxS(..), IIxS,
+ ShS(..), KnownShS(..),
+
+ Mixed,
+ IxX(..), IIxX,
+ KnownShX(..), StaticShX(..),
+
+ Elt,
+ PrimElt,
+ Primitive(..),
+ KnownElt,
+
+ type (++),
+ Storable,
+ SNat, pattern SNat,
+ pattern SZ, pattern SS,
+ Perm(..),
+ IsPermutation,
+ KnownPerm(..),
+ NumElt, FloatElt,
+) where
+
+import Prelude hiding (mappend)
+
+import Data.Array.Nested
+import Data.Array.Nested.Trace.TH
+
+
+$(concat <$> mapM convertFun
+ ['rshape , 'rrank, 'rindex, 'rindexPartial, 'rgenerate, 'rsumOuter1, 'rtranspose, 'rappend, 'rscalar, 'rfromVector, 'rtoVector, 'runScalar, 'rrerank, 'rreplicate, 'rreplicateScal, 'rfromListOuter, 'rfromList1, 'rfromList1Prim, 'rtoListOuter, 'rtoList1, 'rslice, 'rrev1, 'rreshape, 'riota, 'rlift, 'rlift2, 'rtoXArrayPrim, 'rfromXArrayPrim, 'rcastToShaped, 'rfromOrthotope, 'rtoOrthotope, 'sshape, 'sindex, 'sindexPartial, 'sgenerate, 'ssumOuter1, 'stranspose, 'sappend, 'sscalar, 'sfromVector, 'stoVector, 'sunScalar, 'srerank, 'sreplicate, 'sreplicateScal, 'sfromListOuter, 'sfromList1, 'sfromList1Prim, 'stoListOuter, 'stoList1, 'sslice, 'srev1, 'sreshape, 'siota, 'slift, 'slift2, 'stoXArrayPrim, 'sfromXArrayPrim, 'stoRanked, 'mshape, 'mindex, 'mindexPartial, 'mgenerate, 'msumOuter1, 'mtranspose, 'mappend, 'mscalar, 'mfromVector, 'mtoVector, 'munScalar, 'mrerank, 'mreplicate, 'mreplicateScal, 'mfromListOuter, 'mfromList1, 'mfromList1Prim, 'mtoListOuter, 'mtoList1, 'mslice, 'mrev1, 'mreshape, 'miota, 'mlift, 'mlift2, 'mtoXArrayPrim, 'mfromXArrayPrim, 'mtoRanked, 'mcastToShaped])
diff --git a/src/Data/Array/Nested/Trace/TH.hs b/src/Data/Array/Nested/Trace/TH.hs
new file mode 100644
index 0000000..47e53cd
--- /dev/null
+++ b/src/Data/Array/Nested/Trace/TH.hs
@@ -0,0 +1,82 @@
+{-# LANGUAGE ImportQualifiedPost #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TemplateHaskellQuotes #-}
+module Data.Array.Nested.Trace.TH where
+
+import Control.Monad (zipWithM)
+import Data.List (foldl', intersperse)
+import Data.Maybe (isJust)
+import Language.Haskell.TH hiding (cxt)
+
+import Debug.Trace qualified as Debug
+
+import Data.Array.Mixed.Types
+import Data.Array.Nested
+
+
+splitFunTy :: Type -> ([TyVarBndr Specificity], Cxt, [Type], Type)
+splitFunTy = \case
+ ArrowT `AppT` t1 `AppT` t2 ->
+ let (vars, cx, args, ret) = splitFunTy t2
+ in (vars, cx, t1 : args, ret)
+ ForallT vs cx' t ->
+ let (vars, cx, args, ret) = splitFunTy t
+ in (vars ++ vs, cx ++ cx', args, ret)
+ t -> ([], [], [], t)
+
+data Relevant = RRanked Type Type
+ | RShaped Type Type
+ | RMixed Type Type
+ | RShowable Type
+ deriving (Show)
+
+-- | If so, returns the element type
+isRelevant :: Type -> Maybe Relevant
+isRelevant (ConT name `AppT` sht `AppT` ty)
+ | name == ''Ranked = Just (RRanked sht ty)
+ | name == ''Shaped = Just (RShaped sht ty)
+ | name == ''Mixed = Just (RMixed sht ty)
+isRelevant ty@(ConT name `AppT` _)
+ | name `elem` [''IShR, ''IIxR, ''ShS, ''IIxS, ''SNat] =
+ Just (RShowable ty)
+isRelevant _ = Nothing
+
+convertType :: Type -> Q (Type, [Bool], Bool)
+convertType typ =
+ let (tybndrs, cxt, args, ret) = splitFunTy typ
+ argrels = map isRelevant args
+ retrel = isRelevant ret
+
+ showhead (RRanked n ty) = [ConT ''Mixed `AppT` (ConT ''Replicate `AppT` n `AppT` ConT 'Nothing) `AppT` ty]
+ showhead (RShaped sh ty) = [ConT ''Mixed `AppT` (ConT ''MapJust `AppT` sh) `AppT` ty]
+ showhead (RMixed sh ty) = [ConT ''Mixed `AppT` sh `AppT` ty]
+ showhead (RShowable _) = []
+ in return
+ (ForallT tybndrs
+ (cxt ++ [ConT ''Show `AppT` hd
+ | Just rel <- retrel : argrels
+ , hd <- showhead rel])
+ (foldr (\a b -> ArrowT `AppT` a `AppT` b) ret args)
+ ,map isJust argrels
+ ,isJust retrel)
+
+convertFun :: Name -> Q [Dec]
+convertFun funname = do
+ defname <- newName (nameBase funname)
+ (convty, argarrs, retarr) <- reifyType funname >>= convertType
+ names <- zipWithM (\b i -> newName ((if b then "t" else "x") ++ show i)) argarrs [1::Int ..]
+ resname <- newName "res"
+ let tracenames = map fst (filter snd (zip (names ++ [resname]) (argarrs ++ [retarr])))
+ let ex = LetE [ValD (VarP resname)
+ (NormalB (foldl' AppE (VarE funname) (map VarE names)))
+ []]
+ (VarE 'Debug.trace
+ `AppE` (VarE 'concat `AppE` ListE
+ ([LitE (StringL ("oxtrace: " ++ nameBase funname ++ " ["))] ++
+ intersperse (LitE (StringL ", "))
+ (map (\n -> VarE 'show `AppE` VarE n) tracenames) ++
+ [LitE (StringL "]")]))
+ `AppE` VarE resname)
+ return
+ [SigD defname convty
+ ,FunD defname [Clause (map VarP names) (NormalB ex) []]]