diff options
Diffstat (limited to 'src/CHAD/Types.hs')
-rw-r--r-- | src/CHAD/Types.hs | 83 |
1 files changed, 64 insertions, 19 deletions
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs index e8ec0c9..44ac20e 100644 --- a/src/CHAD/Types.hs +++ b/src/CHAD/Types.hs @@ -1,8 +1,10 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module CHAD.Types where +import AST.Accum import AST.Types import Data @@ -11,14 +13,16 @@ type family D1 t where D1 TNil = TNil D1 (TPair a b) = TPair (D1 a) (D1 b) D1 (TEither a b) = TEither (D1 a) (D1 b) + D1 (TLEither a b) = TLEither (D1 a) (D1 b) D1 (TMaybe a) = TMaybe (D1 a) D1 (TArr n t) = TArr n (D1 t) D1 (TScal t) = TScal t type family D2 t where D2 TNil = TNil - D2 (TPair a b) = TMaybe (TPair (D2 a) (D2 b)) - D2 (TEither a b) = TMaybe (TEither (D2 a) (D2 b)) + D2 (TPair a b) = TPair (D2 a) (D2 b) + D2 (TEither a b) = TLEither (D2 a) (D2 b) + D2 (TLEither a b) = TLEither (D2 a) (D2 b) D2 (TMaybe t) = TMaybe (D2 t) D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t @@ -40,12 +44,13 @@ type family D2E env where type family D2AcE env where D2AcE '[] = '[] - D2AcE (t : env) = TAccum t : D2AcE env + D2AcE (t : env) = TAccum (D2 t) : D2AcE env d1 :: STy t -> STy (D1 t) d1 STNil = STNil d1 (STPair a b) = STPair (d1 a) (d1 b) d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STLEither a b) = STLEither (d1 a) (d1 b) d1 (STMaybe t) = STMaybe (d1 t) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t @@ -55,27 +60,34 @@ d1e :: SList STy env -> SList STy (D1E env) d1e SNil = SNil d1e (t `SCons` env) = d1 t `SCons` d1e env +d2M :: STy t -> SMTy (D2 t) +d2M STNil = SMTNil +d2M (STPair a b) = SMTPair (d2M a) (d2M b) +d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STMaybe t) = SMTMaybe (d2M t) +d2M (STArr n t) = SMTArr n (d2M t) +d2M (STScal t) = case t of + STI32 -> SMTNil + STI64 -> SMTNil + STF32 -> SMTScal STF32 + STF64 -> SMTScal STF64 + STBool -> SMTNil +d2M STAccum{} = error "Accumulators not allowed in input program" + d2 :: STy t -> STy (D2 t) -d2 STNil = STNil -d2 (STPair a b) = STMaybe (STPair (d2 a) (d2 b)) -d2 (STEither a b) = STMaybe (STEither (d2 a) (d2 b)) -d2 (STMaybe t) = STMaybe (d2 t) -d2 (STArr n t) = STArr n (d2 t) -d2 (STScal t) = case t of - STI32 -> STNil - STI64 -> STNil - STF32 -> STScal STF32 - STF64 -> STScal STF64 - STBool -> STNil -d2 STAccum{} = error "Accumulators not allowed in input program" +d2 = fromSMTy . d2M + +d2eM :: SList STy env -> SList SMTy (D2E env) +d2eM SNil = SNil +d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts d2e :: SList STy env -> SList STy (D2E env) -d2e SNil = SNil -d2e (t `SCons` ts) = d2 t `SCons` d2e ts +d2e = slistMap fromSMTy . d2eM d2ace :: SList STy env -> SList STy (D2AcE env) d2ace SNil = SNil -d2ace (t `SCons` ts) = STAccum t `SCons` d2ace ts +d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts data CHADConfig = CHADConfig @@ -85,6 +97,8 @@ data CHADConfig = CHADConfig chcCaseArrayAccum :: Bool , -- | Introduce top-level arguments containing arrays in accumulator mode. chcArgArrayAccum :: Bool + , -- | Place with-blocks around array variable scopes, and redirect accumulations there. + chcSmartWith :: Bool } deriving (Show) @@ -93,12 +107,14 @@ defaultConfig = CHADConfig { chcLetArrayAccum = False , chcCaseArrayAccum = False , chcArgArrayAccum = False + , chcSmartWith = False } chcSetAccum :: CHADConfig -> CHADConfig chcSetAccum c = c { chcLetArrayAccum = True , chcCaseArrayAccum = True - , chcArgArrayAccum = True } + , chcArgArrayAccum = True + , chcSmartWith = True } ------------------------------------ LEMMAS ------------------------------------ @@ -106,3 +122,32 @@ chcSetAccum c = c { chcLetArrayAccum = True indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) indexTupD1Id SZ = Refl indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil +lemZeroInfoScal STI32 = Refl +lemZeroInfoScal STI64 = Refl +lemZeroInfoScal STF32 = Refl +lemZeroInfoScal STF64 = Refl +lemZeroInfoScal STBool = Refl + +lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil +lemDeepZeroInfoScal STI32 = Refl +lemDeepZeroInfoScal STI64 = Refl +lemDeepZeroInfoScal STF32 = Refl +lemDeepZeroInfoScal STF64 = Refl +lemDeepZeroInfoScal STBool = Refl + +d1Identity :: STy t -> D1 t :~: t +d1Identity = \case + STNil -> Refl + STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STMaybe t | Refl <- d1Identity t -> Refl + STArr _ t | Refl <- d1Identity t -> Refl + STScal _ -> Refl + STAccum{} -> error "Accumulators not allowed in input program" + +d1eIdentity :: SList STy env -> D1E env :~: env +d1eIdentity SNil = Refl +d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl |