diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-09 22:08:17 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2025-03-09 22:08:17 +0100 |
commit | c3b4f56760547940256afea8e692681dbbe21857 (patch) | |
tree | 04e7aeee8ebbd78f937c7b4e34a08bec995beca9 | |
parent | da5dbc4ebca51a32b43bec360470c037cab1755f (diff) |
Clean up code organisation a little
-rw-r--r-- | src/AST.hs | 23 | ||||
-rw-r--r-- | src/AST/Types.hs | 25 | ||||
-rw-r--r-- | src/CHAD.hs | 23 | ||||
-rw-r--r-- | src/CHAD/Types.hs | 7 |
4 files changed, 37 insertions, 41 deletions
@@ -129,32 +129,9 @@ type Ex = Expr (Const ()) ext :: Const () a ext = Const () -type family Tup env where - Tup '[] = TNil - Tup (t : ts) = TPair (Tup ts) t - -mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b)) - -> SList f list -> f (Tup list) -mkTup nil _ SNil = nil -mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e - -tTup :: SList STy env -> STy (Tup env) -tTup = mkTup STNil STPair - eTup :: SList (Ex env) list -> Ex env (Tup list) eTup = mkTup (ENil ext) (EPair ext) -unTup :: (forall a b. c (TPair a b) -> (c a, c b)) - -> SList f list -> c (Tup list) -> SList c list -unTup _ SNil _ = SNil -unTup unpack (_ `SCons` list) tup = - let (xs, x) = unpack tup - in x `SCons` unTup unpack list xs - -type family InvTup core env where - InvTup core '[] = core - InvTup core (t : ts) = InvTup (TPair core t) ts - type SOp :: Ty -> Ty -> Type data SOp a t where OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) diff --git a/src/AST/Types.hs b/src/AST/Types.hs index acf7053..be7cffe 100644 --- a/src/AST/Types.hs +++ b/src/AST/Types.hs @@ -1,8 +1,10 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} module AST.Types where import Data.Int (Int32, Int64) @@ -117,3 +119,26 @@ hasArrays (STMaybe t) = hasArrays t hasArrays STArr{} = True hasArrays STScal{} = False hasArrays STAccum{} = True + +type family Tup env where + Tup '[] = TNil + Tup (t : ts) = TPair (Tup ts) t + +mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b)) + -> SList f list -> f (Tup list) +mkTup nil _ SNil = nil +mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e + +tTup :: SList STy env -> STy (Tup env) +tTup = mkTup STNil STPair + +unTup :: (forall a b. c (TPair a b) -> (c a, c b)) + -> SList f list -> c (Tup list) -> SList c list +unTup _ SNil _ = SNil +unTup unpack (_ `SCons` list) tup = + let (xs, x) = unpack tup + in x `SCons` unTup unpack list xs + +type family InvTup core env where + InvTup core '[] = core + InvTup core (t : ts) = InvTup (TPair core t) ts diff --git a/src/CHAD.hs b/src/CHAD.hs index 6118e48..3f76922 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -309,24 +309,11 @@ conv2Idx (DPush des (_, SDiscr)) (IS i) = conv2Idx DTop i = case i of {} ------------------------------------- LEMMAS ------------------------------------ - -indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) -indexTupD1Id SZ = Refl -indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl - - ------------------------------------ MONOIDS ----------------------------------- -zero :: STy t -> Ex env (D2 t) -zero = EZero ext - -plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) -plus = EPlus ext - zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext -zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) +zeroTup (SCons t env) = EPair ext (zeroTup env) (EZero ext t) ------------------------------------ SUBENVS ----------------------------------- @@ -366,7 +353,7 @@ subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k = ELet ext (weakenExpr WSink e2) $ EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) (EFst ext (EVar ext (typeOf e2) IZ))) - (plus t + (EPlus ext t (ESnd ext (EVar ext (typeOf e1) (IS IZ))) (ESnd ext (EVar ext (typeOf e2) IZ))) @@ -376,7 +363,7 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e = ELet ext e $ let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) -expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) +expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (EZero ext t) assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl @@ -664,7 +651,7 @@ drev des = \case subtape (EFst ext e1) sub - (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $ + (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (EZero ext t2))) $ weakenExpr (WCopy WSink) e2) ESnd _ e @@ -674,7 +661,7 @@ drev des = \case subtape (ESnd ext e1) sub - (ELet ext (EJust ext (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ + (ELet ext (EJust ext (EPair ext (EZero ext t1) (EVar ext (d2 t2) IZ))) $ weakenExpr (WCopy WSink) e2) ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index fd1b6b1..a8614cf 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -99,3 +99,10 @@ chcSetAccum :: CHADConfig -> CHADConfig chcSetAccum c = c { chcLetArrayAccum = True , chcCaseArrayAccum = True , chcArgArrayAccum = True } + + +------------------------------------ LEMMAS ------------------------------------ + +indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) +indexTupD1Id SZ = Refl +indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl |