summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2025-03-09 22:08:17 +0100
committerTom Smeding <t.j.smeding@uu.nl>2025-03-09 22:08:17 +0100
commitc3b4f56760547940256afea8e692681dbbe21857 (patch)
tree04e7aeee8ebbd78f937c7b4e34a08bec995beca9
parentda5dbc4ebca51a32b43bec360470c037cab1755f (diff)
Clean up code organisation a little
-rw-r--r--src/AST.hs23
-rw-r--r--src/AST/Types.hs25
-rw-r--r--src/CHAD.hs23
-rw-r--r--src/CHAD/Types.hs7
4 files changed, 37 insertions, 41 deletions
diff --git a/src/AST.hs b/src/AST.hs
index e22d11f..0e040d4 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -129,32 +129,9 @@ type Ex = Expr (Const ())
ext :: Const () a
ext = Const ()
-type family Tup env where
- Tup '[] = TNil
- Tup (t : ts) = TPair (Tup ts) t
-
-mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b))
- -> SList f list -> f (Tup list)
-mkTup nil _ SNil = nil
-mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e
-
-tTup :: SList STy env -> STy (Tup env)
-tTup = mkTup STNil STPair
-
eTup :: SList (Ex env) list -> Ex env (Tup list)
eTup = mkTup (ENil ext) (EPair ext)
-unTup :: (forall a b. c (TPair a b) -> (c a, c b))
- -> SList f list -> c (Tup list) -> SList c list
-unTup _ SNil _ = SNil
-unTup unpack (_ `SCons` list) tup =
- let (xs, x) = unpack tup
- in x `SCons` unTup unpack list xs
-
-type family InvTup core env where
- InvTup core '[] = core
- InvTup core (t : ts) = InvTup (TPair core t) ts
-
type SOp :: Ty -> Ty -> Type
data SOp a t where
OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index acf7053..be7cffe 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -1,8 +1,10 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
module AST.Types where
import Data.Int (Int32, Int64)
@@ -117,3 +119,26 @@ hasArrays (STMaybe t) = hasArrays t
hasArrays STArr{} = True
hasArrays STScal{} = False
hasArrays STAccum{} = True
+
+type family Tup env where
+ Tup '[] = TNil
+ Tup (t : ts) = TPair (Tup ts) t
+
+mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b))
+ -> SList f list -> f (Tup list)
+mkTup nil _ SNil = nil
+mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e
+
+tTup :: SList STy env -> STy (Tup env)
+tTup = mkTup STNil STPair
+
+unTup :: (forall a b. c (TPair a b) -> (c a, c b))
+ -> SList f list -> c (Tup list) -> SList c list
+unTup _ SNil _ = SNil
+unTup unpack (_ `SCons` list) tup =
+ let (xs, x) = unpack tup
+ in x `SCons` unTup unpack list xs
+
+type family InvTup core env where
+ InvTup core '[] = core
+ InvTup core (t : ts) = InvTup (TPair core t) ts
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 6118e48..3f76922 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -309,24 +309,11 @@ conv2Idx (DPush des (_, SDiscr)) (IS i) =
conv2Idx DTop i = case i of {}
------------------------------------- LEMMAS ------------------------------------
-
-indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
-indexTupD1Id SZ = Refl
-indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl
-
-
------------------------------------ MONOIDS -----------------------------------
-zero :: STy t -> Ex env (D2 t)
-zero = EZero ext
-
-plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
-plus = EPlus ext
-
zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext
-zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
+zeroTup (SCons t env) = EPair ext (zeroTup env) (EZero ext t)
------------------------------------ SUBENVS -----------------------------------
@@ -366,7 +353,7 @@ subenvPlus (SCons t env) (SEYes sub1) (SEYes sub2) k =
ELet ext (weakenExpr WSink e2) $
EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ)))
(EFst ext (EVar ext (typeOf e2) IZ)))
- (plus t
+ (EPlus ext t
(ESnd ext (EVar ext (typeOf e1) (IS IZ)))
(ESnd ext (EVar ext (typeOf e2) IZ)))
@@ -376,7 +363,7 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e =
ELet ext e $
let var = EVar ext (STPair (tTup (d2e (subList ts sub))) (d2 t)) IZ
in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
-expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t)
+expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (EZero ext t)
assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
@@ -664,7 +651,7 @@ drev des = \case
subtape
(EFst ext e1)
sub
- (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $
+ (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (EZero ext t2))) $
weakenExpr (WCopy WSink) e2)
ESnd _ e
@@ -674,7 +661,7 @@ drev des = \case
subtape
(ESnd ext e1)
sub
- (ELet ext (EJust ext (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
+ (ELet ext (EJust ext (EPair ext (EZero ext t1) (EVar ext (d2 t2) IZ))) $
weakenExpr (WCopy WSink) e2)
ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index fd1b6b1..a8614cf 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -99,3 +99,10 @@ chcSetAccum :: CHADConfig -> CHADConfig
chcSetAccum c = c { chcLetArrayAccum = True
, chcCaseArrayAccum = True
, chcArgArrayAccum = True }
+
+
+------------------------------------ LEMMAS ------------------------------------
+
+indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx))
+indexTupD1Id SZ = Refl
+indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl