diff options
Diffstat (limited to 'src/CHAD/Drev/Types.hs')
| -rw-r--r-- | src/CHAD/Drev/Types.hs | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs new file mode 100644 index 0000000..367a974 --- /dev/null +++ b/src/CHAD/Drev/Types.hs @@ -0,0 +1,153 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Drev.Types where + +import CHAD.AST.Accum +import CHAD.AST.Types +import CHAD.Data + + +type family D1 t where + D1 TNil = TNil + D1 (TPair a b) = TPair (D1 a) (D1 b) + D1 (TEither a b) = TEither (D1 a) (D1 b) + D1 (TLEither a b) = TLEither (D1 a) (D1 b) + D1 (TMaybe a) = TMaybe (D1 a) + D1 (TArr n t) = TArr n (D1 t) + D1 (TScal t) = TScal t + +type family D2 t where + D2 TNil = TNil + D2 (TPair a b) = TPair (D2 a) (D2 b) + D2 (TEither a b) = TLEither (D2 a) (D2 b) + D2 (TLEither a b) = TLEither (D2 a) (D2 b) + D2 (TMaybe t) = TMaybe (D2 t) + D2 (TArr n t) = TArr n (D2 t) + D2 (TScal t) = D2s t + +type family D2s t where + D2s TI32 = TNil + D2s TI64 = TNil + D2s TF32 = TScal TF32 + D2s TF64 = TScal TF64 + D2s TBool = TNil + +type family D1E env where + D1E '[] = '[] + D1E (t : env) = D1 t : D1E env + +type family D2E env where + D2E '[] = '[] + D2E (t : env) = D2 t : D2E env + +type family D2AcE env where + D2AcE '[] = '[] + D2AcE (t : env) = TAccum (D2 t) : D2AcE env + +d1 :: STy t -> STy (D1 t) +d1 STNil = STNil +d1 (STPair a b) = STPair (d1 a) (d1 b) +d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STLEither a b) = STLEither (d1 a) (d1 b) +d1 (STMaybe t) = STMaybe (d1 t) +d1 (STArr n t) = STArr n (d1 t) +d1 (STScal t) = STScal t +d1 STAccum{} = error "Accumulators not allowed in input program" + +d1e :: SList STy env -> SList STy (D1E env) +d1e SNil = SNil +d1e (t `SCons` env) = d1 t `SCons` d1e env + +d2M :: STy t -> SMTy (D2 t) +d2M STNil = SMTNil +d2M (STPair a b) = SMTPair (d2M a) (d2M b) +d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STMaybe t) = SMTMaybe (d2M t) +d2M (STArr n t) = SMTArr n (d2M t) +d2M (STScal t) = case t of + STI32 -> SMTNil + STI64 -> SMTNil + STF32 -> SMTScal STF32 + STF64 -> SMTScal STF64 + STBool -> SMTNil +d2M STAccum{} = error "Accumulators not allowed in input program" + +d2 :: STy t -> STy (D2 t) +d2 = fromSMTy . d2M + +d2eM :: SList STy env -> SList SMTy (D2E env) +d2eM SNil = SNil +d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts + +d2e :: SList STy env -> SList STy (D2E env) +d2e = slistMap fromSMTy . d2eM + +d2ace :: SList STy env -> SList STy (D2AcE env) +d2ace SNil = SNil +d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts + + +data CHADConfig = CHADConfig + { -- | D[let] will bind variables containing arrays in accumulator mode. + chcLetArrayAccum :: Bool + , -- | D[case] will bind variables containing arrays in accumulator mode. + chcCaseArrayAccum :: Bool + , -- | Introduce top-level arguments containing arrays in accumulator mode. + chcArgArrayAccum :: Bool + , -- | Place with-blocks around array variable scopes, and redirect accumulations there. + chcSmartWith :: Bool + } + deriving (Show) + +defaultConfig :: CHADConfig +defaultConfig = CHADConfig + { chcLetArrayAccum = False + , chcCaseArrayAccum = False + , chcArgArrayAccum = False + , chcSmartWith = False + } + +chcSetAccum :: CHADConfig -> CHADConfig +chcSetAccum c = c { chcLetArrayAccum = True + , chcCaseArrayAccum = True + , chcArgArrayAccum = True + , chcSmartWith = True } + + +------------------------------------ LEMMAS ------------------------------------ + +indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) +indexTupD1Id SZ = Refl +indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil +lemZeroInfoScal STI32 = Refl +lemZeroInfoScal STI64 = Refl +lemZeroInfoScal STF32 = Refl +lemZeroInfoScal STF64 = Refl +lemZeroInfoScal STBool = Refl + +lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil +lemDeepZeroInfoScal STI32 = Refl +lemDeepZeroInfoScal STI64 = Refl +lemDeepZeroInfoScal STF32 = Refl +lemDeepZeroInfoScal STF64 = Refl +lemDeepZeroInfoScal STBool = Refl + +d1Identity :: STy t -> D1 t :~: t +d1Identity = \case + STNil -> Refl + STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STMaybe t | Refl <- d1Identity t -> Refl + STArr _ t | Refl <- d1Identity t -> Refl + STScal _ -> Refl + STAccum{} -> error "Accumulators not allowed in input program" + +d1eIdentity :: SList STy env -> D1E env :~: env +d1eIdentity SNil = Refl +d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl |
