From 174af2ba568de66e0d890825b8bda930b8e7bb96 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Nov 2025 21:49:45 +0100 Subject: Move module hierarchy under CHAD. --- src/CHAD/Interpreter/Rep.hs | 105 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 src/CHAD/Interpreter/Rep.hs (limited to 'src/CHAD/Interpreter/Rep.hs') diff --git a/src/CHAD/Interpreter/Rep.hs b/src/CHAD/Interpreter/Rep.hs new file mode 100644 index 0000000..fadc6be --- /dev/null +++ b/src/CHAD/Interpreter/Rep.hs @@ -0,0 +1,105 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.Interpreter.Rep where + +import Control.DeepSeq +import Data.Coerce (coerce) +import Data.List (intersperse, intercalate) +import Data.Foldable (toList) +import Data.IORef +import GHC.Exts (withDict) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.Data + + +type family Rep t where + Rep TNil = () + Rep (TPair a b) = (Rep a, Rep b) + Rep (TEither a b) = Either (Rep a) (Rep b) + Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b)) + Rep (TMaybe t) = Maybe (Rep t) + Rep (TArr n t) = Array n (Rep t) + Rep (TScal sty) = ScalRep sty + Rep (TAccum t) = RepAc t + +-- Mutable, represents monoid types t. +type family RepAc t where + RepAc TNil = () + RepAc (TPair a b) = (RepAc a, RepAc b) + RepAc (TLEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b))) + RepAc (TMaybe t) = IORef (Maybe (RepAc t)) + RepAc (TArr n t) = Array n (RepAc t) + RepAc (TScal sty) = IORef (ScalRep sty) + +newtype Value t = Value { unValue :: Rep t } + +liftV :: (Rep a -> Rep b) -> Value a -> Value b +liftV f (Value x) = Value (f x) + +liftV2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c +liftV2 f (Value x) (Value y) = Value (f x y) + +vPair :: Value a -> Value b -> Value (TPair a b) +vPair = liftV2 (,) + +vUnpair :: Value (TPair a b) -> (Value a, Value b) +vUnpair (Value (x, y)) = (Value x, Value y) + +showValue :: Int -> STy t -> Rep t -> ShowS +showValue _ STNil () = showString "()" +showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" +showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x +showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y +showValue _ (STLEither _ _) Nothing = showString "LNil" +showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x +showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y +showValue _ (STMaybe _) Nothing = showString "Nothing" +showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x +showValue d (STArr _ t) arr = showParen (d > 10) $ + showString "arrayFromList " . showsPrec 11 (arrayShape arr) + . showString " [" + . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) + . showString "]" +showValue d (STScal sty) x = case sty of + STF32 -> showsPrec d x + STF64 -> showsPrec d x + STI32 -> showsPrec d x + STI64 -> showsPrec d x + STBool -> showsPrec d x +showValue _ (STAccum t) _ = showString $ "" + +showEnv :: SList STy env -> SList Value env -> String +showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" + where + showEntries :: SList STy env -> SList Value env -> [String] + showEntries SNil SNil = [] + showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs + +rnfRep :: STy t -> Rep t -> () +rnfRep STNil () = () +rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y +rnfRep (STEither a _) (Left x) = rnfRep a x +rnfRep (STEither _ b) (Right y) = rnfRep b y +rnfRep (STLEither _ _) Nothing = () +rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x +rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y +rnfRep (STMaybe _) Nothing = () +rnfRep (STMaybe t) (Just x) = rnfRep t x +rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr = + withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr) +rnfRep (STScal t) x = case t of + STI32 -> rnf x + STI64 -> rnf x + STF32 -> rnf x + STF64 -> rnf x + STBool -> rnf x +rnfRep STAccum{} _ = error "Cannot rnf accumulators" + +instance KnownTy t => NFData (Value t) where + rnf (Value x) = rnfRep (knownTy @t) x -- cgit v1.2.3-70-g09d2