diff options
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | AD.hs | 152 | ||||
-rw-r--r-- | AST.hs | 288 | ||||
-rw-r--r-- | Examples.hs | 28 | ||||
-rw-r--r-- | Gradient.hs | 23 | ||||
-rw-r--r-- | Language.hs | 52 | ||||
-rw-r--r-- | Main.hs | 5 | ||||
-rw-r--r-- | README.md | 7 | ||||
-rw-r--r-- | Repl.hs | 11 | ||||
-rw-r--r-- | Simplify.hs | 203 | ||||
-rw-r--r-- | Sink.hs | 47 | ||||
-rw-r--r-- | ftilde.cabal | 27 |
12 files changed, 844 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c33954f --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +dist-newstyle/ @@ -0,0 +1,152 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module AD ( + Dual, + dual, + ad, +) where + +import Data.Bifunctor +import Data.Type.Equality + +import AST +import qualified Language as L +import Sink + + +-- | Dual-number version of a type. Pairs up every Double with a tangent copy. +type family Dual a where + Dual () = () + Dual Int = Int + Dual Bool = Bool + Dual Double = (Double, Double) + Dual (a, b) = (Dual a, Dual b) + Dual (Array sh a) = Array sh (Dual a) + Dual (a -> b) = Dual a -> Dual b + +data DEnv env env' where + ETop :: DEnv env env + ECons :: Type a -> DEnv env env' -> DEnv (a ': env) (Dual a ': env') + +-- | Convert a type to its 'Dual' variant. +dual :: Type a -> Type (Dual a) +dual (TFun a b) = TFun (dual a) (dual b) +dual TInt = TInt +dual TBool = TBool +dual TDouble = TPair TDouble TDouble +dual (TArray sht t) = TArray sht (dual t) +dual TNil = TNil +dual (TPair a b) = TPair (dual a) (dual b) + +-- | Forward AD. +-- +-- Since @'Dual' (a -> b) = 'Dual' a -> 'Dual' b@, calling 'ad' on a +-- function-typed expression gives the most useful results. +ad :: Exp env a -> Exp env (Dual a) +ad = ad' ETop + +ad' :: DEnv env env' -> Exp env a -> Exp env' (Dual a) +ad' env = \case + App e1 e2 -> App (ad' env e1) (ad' env e2) + Lam t e -> Lam (dual t) (ad' (ECons t env) e) + Var t i -> + case convIdx env i of + Left freei -> App (duale t) (Var t freei) + Right duali -> Var (dual t) duali + Let e1 e2 -> Let (ad' env e1) (ad' (ECons (typeof e1) env) e2) + Lit l -> App (duale (literalType l)) (Lit l) + Cond e1 e2 e3 -> Cond (ad' env e1) (ad' env e2) (ad' env e3) + Const CAddI -> Const CAddI + Const CSubI -> Const CSubI + Const CMulI -> Const CMulI + Const CDivI -> Const CDivI + Const CAddF -> + let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero + in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) + (Pair (App (Const CAddF) (Pair (Fst (Fst v)) (Fst (Snd v)))) + (App (Const CAddF) (Pair (Snd (Fst v)) (Snd (Snd v))))) + Const CSubF -> + let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero + in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) + (Pair (App (Const CSubF) (Pair (Fst (Fst v)) (Fst (Snd v)))) + (App (Const CSubF) (Pair (Snd (Fst v)) (Snd (Snd v))))) + Const CMulF -> + let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero + in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) + (Pair (App (Const CMulF) (Pair (Fst (Fst v)) (Fst (Snd v)))) + (App (Const CAddF) (Pair + (App (Const CMulF) (Pair (Fst (Fst v)) (Snd (Snd v)))) + (App (Const CMulF) (Pair (Fst (Snd v)) (Snd (Fst v))))))) + Const _ -> undefined + Pair e1 e2 -> Pair (ad' env e1) (ad' env e2) + Fst e -> Fst (ad' env e) + Snd e -> Snd (ad' env e) + Build sht e1 e2 + | Refl <- prfDualSht sht + -> Build sht (ad' env e1) (ad' env e2) + Ifold sht e1 e2 e3 + | Refl <- prfDualSht sht + -> Ifold sht (ad' env e1) (ad' env e2) (ad' env e3) + Index e1 e2 + | TArray sht _ <- typeof e1 + , Refl <- prfDualSht sht + -> Index (ad' env e1) (ad' env e2) + Shape e + | TArray sht _ <- typeof e + , Refl <- prfDualSht sht + -> Shape (ad' env e) + +convIdx :: DEnv env env' -> Idx env a -> Either (Idx env' a) (Idx env' (Dual a)) +convIdx ETop i = Left i +convIdx (ECons _ _) Zero = Right Zero +convIdx (ECons _ env) (Succ i) = bimap Succ Succ (convIdx env i) + +duale :: Type a -> Exp env (a -> Dual a) +duale topty = Lam topty (go topty) + where + go :: forall a env. Type a -> Exp (a ': env) (Dual a) + go ty = case ty of + TInt -> ref + TBool -> ref + TDouble -> Pair ref (Lit (LDouble 0)) + TArray _ t -> L.map (Lam t (go t)) ref + TNil -> Lit LNil + TPair t1 t2 -> Pair (Let (Fst ref) (go t1)) (Let (Snd ref) (go t2)) + TFun t1 t2 -> Lam (dual t1) + (App (duale t2) + (App (sinkExp1 ref) + (App (unduale t1) + (Var (dual t1) Zero)))) + where + ref :: Exp (a ': env) a + ref = Var ty Zero + +unduale :: Type a -> Exp env (Dual a -> a) +unduale topty = Lam (dual topty) (go topty) + where + go :: forall a env. Type a -> Exp (Dual a ': env) a + go ty = case ty of + TInt -> ref + TBool -> ref + TDouble -> Fst ref + TArray _ t -> L.map (Lam (dual t) (go t)) ref + TNil -> Lit LNil + TPair t1 t2 -> Pair (Let (Fst ref) (go t1)) (Let (Snd ref) (go t2)) + TFun t1 t2 -> Lam t1 + (App (unduale t2) + (App (sinkExp1 ref) + (App (duale t1) + (Var t1 Zero)))) + where + ref :: Exp (Dual a ': env) (Dual a) + ref = Var (dual ty) Zero + +prfDualSht :: ShapeType sh -> sh :~: Dual sh +prfDualSht STZ = Refl +prfDualSht (STC sht) + | Refl <- prfDualSht sht + = Refl @@ -0,0 +1,288 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module AST where + +import Data.GADT.Compare +import Data.Type.Equality +import qualified Data.Vector as V +import Data.Vector (Vector) + + +data Exp env a where + App :: Exp env (a -> b) -> Exp env a -> Exp env b + Lam :: Type t -> Exp (t ': env) a -> Exp env (t -> a) + Var :: Type a -> Idx env a -> Exp env a + Let :: Exp env t -> Exp (t ': env) a -> Exp env a + Lit :: Literal a -> Exp env a + Cond :: Exp env Bool -> Exp env a -> Exp env a -> Exp env a + Const :: Constant a -> Exp env a + Pair :: Exp env a -> Exp env b -> Exp env (a, b) + Fst :: Exp env (a, b) -> Exp env a + Snd :: Exp env (a, b) -> Exp env b + Build :: ShapeType sh -> Exp env sh -> Exp env (sh -> a) -> Exp env (Array sh a) + Ifold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Exp env s + Index :: Exp env (Array sh a) -> Exp env sh -> Exp env a + Shape :: Exp env (Array sh a) -> Exp env sh + +data Constant a where + CAddI :: Constant ((Int, Int) -> Int) + CSubI :: Constant ((Int, Int) -> Int) + CMulI :: Constant ((Int, Int) -> Int) + CDivI :: Constant ((Int, Int) -> Int) + CAddF :: Constant ((Double, Double) -> Double) + CSubF :: Constant ((Double, Double) -> Double) + CMulF :: Constant ((Double, Double) -> Double) + CDivF :: Constant ((Double, Double) -> Double) + CLog :: Constant (Double -> Double) + CExp :: Constant (Double -> Double) + CtoF :: Constant (Int -> Double) + CRound :: Constant (Double -> Int) + + CLtI :: Constant ((Int, Int) -> Bool) + CLtF :: Constant ((Double, Double) -> Bool) + CEq :: Type a -> Constant ((a, a) -> Bool) + CAnd :: Constant ((Bool, Bool) -> Bool) + COr :: Constant ((Bool, Bool) -> Bool) + CNot :: Constant (Bool -> Bool) + +data Type a where + TInt :: Type Int + TBool :: Type Bool + TDouble :: Type Double + TArray :: ShapeType sh -> Type a -> Type (Array sh a) + TNil :: Type () + TPair :: Type a -> Type b -> Type (a, b) + TFun :: Type a -> Type b -> Type (a -> b) + +data Idx env a where + Zero :: Idx (a ': env) a + Succ :: Idx env a -> Idx (t ': env) a + +data Literal a where + LInt :: Int -> Literal Int + LBool :: Bool -> Literal Bool + LDouble :: Double -> Literal Double + LArray :: Array sh a -> Literal (Array sh a) + LShape :: Shape sh -> Literal sh + LNil :: Literal () + LPair :: Literal a -> Literal b -> Literal (a, b) + +data Shape sh where + Z :: Shape () + (:.) :: Int -> Shape sh -> Shape (Int, sh) + +data ShapeType sh where + STZ :: ShapeType () + STC :: ShapeType sh -> ShapeType (Int, sh) + +data Array sh a where + Array :: Shape sh -> Type a -> Vector a -> Array sh a + +deriving instance Show (Exp env a) +deriving instance Show (Constant a) +deriving instance Show (Type a) +deriving instance Show (Idx env a) +deriving instance Show (Literal a) +deriving instance Show (Shape a) +deriving instance Show (ShapeType a) + +instance Show (Array sh a) where + showsPrec p (Array sh t v) = + showParen (p > 10) $ + showString "Array " + . showsPrec 11 sh + . showsPrec 11 t + . (case typeHasShow t of + Just Has -> showsPrec 11 v + Nothing -> showString ("[_ * " ++ show (V.length v) ++ "]")) + +deriving instance Eq (Type a) +deriving instance Eq (Shape sh) +deriving instance Eq (ShapeType sh) +deriving instance Eq a => Eq (Array sh a) + +instance GEq (Exp env) where + geq (App a b) (App a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq App{} _ = Nothing + geq (Lam t e) (Lam t' e') | Just Refl <- geq t t', Just Refl <- geq e e' = Just Refl + geq Lam{} _ = Nothing + geq (Var t i) (Var t' i') | Just Refl <- geq t t', Just Refl <- geq i i' = Just Refl + geq Var{} _ = Nothing + geq (Let a b) (Let a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Let{} _ = Nothing + geq (Lit l) (Lit l') | Just Refl <- geq l l' = Just Refl + geq Lit{} _ = Nothing + geq (Cond a b c) (Cond a' b' c') | Just Refl <- geq a a', Just Refl <- geq b b', Just Refl <- geq c c' = Just Refl + geq Cond{} _ = Nothing + geq (Const c) (Const c') | Just Refl <- geq c c' = Just Refl + geq Const{} _ = Nothing + geq (Pair a b) (Pair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Pair{} _ = Nothing + geq (Fst a) (Fst a') | Just Refl <- geq a a' = Just Refl + geq Fst{} _ = Nothing + geq (Snd a) (Snd a') | Just Refl <- geq a a' = Just Refl + geq Snd{} _ = Nothing + geq (Build t a b) (Build t' a' b') | Just Refl <- geq t t', Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Build{} _ = Nothing + geq (Ifold t a b c) (Ifold t' a' b' c') | Just Refl <- geq t t', Just Refl <- geq a a', Just Refl <- geq b b' , Just Refl <- geq c c' + = Just Refl + geq Ifold{} _ = Nothing + geq (Index a b) (Index a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Index{} _ = Nothing + geq (Shape a) (Shape a') | Just Refl <- geq a a' = Just Refl + geq Shape{} _ = Nothing + +instance GEq Constant where + geq CAddI CAddI = Just Refl ; geq CAddI _ = Nothing + geq CSubI CSubI = Just Refl ; geq CSubI _ = Nothing + geq CMulI CMulI = Just Refl ; geq CMulI _ = Nothing + geq CDivI CDivI = Just Refl ; geq CDivI _ = Nothing + geq CAddF CAddF = Just Refl ; geq CAddF _ = Nothing + geq CSubF CSubF = Just Refl ; geq CSubF _ = Nothing + geq CMulF CMulF = Just Refl ; geq CMulF _ = Nothing + geq CDivF CDivF = Just Refl ; geq CDivF _ = Nothing + geq CLog CLog = Just Refl ; geq CLog _ = Nothing + geq CExp CExp = Just Refl ; geq CExp _ = Nothing + geq CtoF CtoF = Just Refl ; geq CtoF _ = Nothing + geq CRound CRound = Just Refl ; geq CRound _ = Nothing + geq CLtI CLtI = Just Refl ; geq CLtI _ = Nothing + geq CLtF CLtF = Just Refl ; geq CLtF _ = Nothing + geq (CEq t) (CEq t') | Just Refl <- geq t t' = Just Refl ; geq CEq{} _ = Nothing + geq CAnd CAnd = Just Refl ; geq CAnd _ = Nothing + geq COr COr = Just Refl ; geq COr _ = Nothing + geq CNot CNot = Just Refl ; geq CNot _ = Nothing + +instance GEq Type where + geq TInt TInt = Just Refl ; geq TInt _ = Nothing + geq TBool TBool = Just Refl ; geq TBool _ = Nothing + geq TDouble TDouble = Just Refl ; geq TDouble _ = Nothing + geq (TArray sht t) (TArray sht' t') | Just Refl <- geq sht sht', Just Refl <- geq t t' = Just Refl ; geq TArray{} _ = Nothing + geq TNil TNil = Just Refl ; geq TNil _ = Nothing + geq (TPair a b) (TPair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq TPair{} _ = Nothing + geq (TFun a b) (TFun a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq TFun{} _ = Nothing + +instance GEq (Idx env) where + geq Zero Zero = Just Refl + geq (Succ i) (Succ i') | Just Refl <- geq i i' = Just Refl + geq _ _ = Nothing + +instance GEq Literal where + geq (LInt a) (LInt a') | a == a' = Just Refl ; geq LInt{} _ = Nothing + geq (LBool a) (LBool a') | a == a' = Just Refl ; geq LBool{} _ = Nothing + geq (LDouble a) (LDouble a') | a == a' = Just Refl ; geq LDouble{} _ = Nothing + geq (LArray (Array sht t v)) (LArray (Array sht' t' v')) + | Just Refl <- geq sht sht' + , Just Refl <- geq t t' + = case typeHasEq t of + Just Has | v == v' -> Just Refl + | otherwise -> Nothing + Nothing -> error "GEq Literal: Literal array of incomparable values" + geq LArray{} _ = Nothing + geq (LShape a) (LShape a') | Just Refl <- geq a a' = Just Refl ; geq LShape{} _ = Nothing + geq LNil LNil = Just Refl ; geq LNil _ = Nothing + geq (LPair a b) (LPair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq LPair{} _ = Nothing + +instance GEq Shape where + geq Z Z = Just Refl + geq (n :. sh) (n' :. sh') | n == n', Just Refl <- geq sh sh' = Just Refl + geq _ _ = Nothing + +instance GEq ShapeType where + geq STZ STZ = Just Refl + geq (STC sht) (STC sht') | Just Refl <- geq sht sht' = Just Refl + geq _ _ = Nothing + +shapeType :: Shape sh -> ShapeType sh +shapeType Z = STZ +shapeType (_ :. sh) = STC (shapeType sh) + +shapeType' :: Shape sh -> Type sh +shapeType' Z = TNil +shapeType' (_ :. sh) = TPair TInt (shapeType' sh) + +shapeTypeType :: ShapeType sh -> Type sh +shapeTypeType STZ = TNil +shapeTypeType (STC sht) = TPair TInt (shapeTypeType sht) + +literalType :: Literal a -> Type a +literalType LInt{} = TInt +literalType LBool{} = TBool +literalType LDouble{} = TDouble +literalType (LArray (Array sh t _)) = TArray (shapeType sh) t +literalType (LShape sh) = shapeType' sh +literalType LNil{} = TNil +literalType (LPair a b) = TPair (literalType a) (literalType b) + +constType :: Constant a -> Type a +constType CAddI = TFun (TPair TInt TInt) TInt +constType CSubI = TFun (TPair TInt TInt) TInt +constType CMulI = TFun (TPair TInt TInt) TInt +constType CDivI = TFun (TPair TInt TInt) TInt +constType CAddF = TFun (TPair TDouble TDouble) TDouble +constType CSubF = TFun (TPair TDouble TDouble) TDouble +constType CMulF = TFun (TPair TDouble TDouble) TDouble +constType CDivF = TFun (TPair TDouble TDouble) TDouble +constType CLog = TFun TDouble TDouble +constType CExp = TFun TDouble TDouble +constType CtoF = TFun TInt TDouble +constType CRound = TFun TDouble TInt +constType CLtI = TFun (TPair TInt TInt) TBool +constType CLtF = TFun (TPair TDouble TDouble) TBool +constType (CEq t) = TFun (TPair t t) TBool +constType CAnd = TFun (TPair TBool TBool) TBool +constType COr = TFun (TPair TBool TBool) TBool +constType CNot = TFun TBool TBool + +typeof :: Exp env a -> Type a +typeof (App e _) = let TFun _ t = typeof e in t +typeof (Lam t e) = TFun t (typeof e) +typeof (Var t _) = t +typeof (Let _ e) = typeof e +typeof (Lit l) = literalType l +typeof (Cond _ e _) = typeof e +typeof (Const c) = constType c +typeof (Pair e1 e2) = TPair (typeof e1) (typeof e2) +typeof (Fst e) = let TPair t _ = typeof e in t +typeof (Snd e) = let TPair _ t = typeof e in t +typeof (Build sht _ e) = let TFun _ t = typeof e in TArray sht t +typeof (Ifold _ _ e _) = typeof e +typeof (Index e _) = let TArray _ t = typeof e in t +typeof (Shape e) = let TArray sht _ = typeof e in shapeTypeType sht + +data Has c a where + Has :: c a => Has c a + +typeHasShow :: Type a -> Maybe (Has Show a) +typeHasShow TInt = Just Has +typeHasShow TBool = Just Has +typeHasShow TDouble = Just Has +typeHasShow TArray{} = Just Has +typeHasShow TNil = Just Has +typeHasShow (TPair a b) + | Just Has <- typeHasShow a + , Just Has <- typeHasShow b + = Just Has + | otherwise + = Nothing +typeHasShow TFun{} = Nothing + +typeHasEq :: Type a -> Maybe (Has Eq a) +typeHasEq TInt = Just Has +typeHasEq TBool = Just Has +typeHasEq TDouble = Just Has +typeHasEq (TArray _ t) + | Just Has <- typeHasEq t + = Just Has + | otherwise + = Nothing +typeHasEq TNil = Just Has +typeHasEq (TPair a b) + | Just Has <- typeHasEq a + , Just Has <- typeHasEq b + = Just Has + | otherwise + = Nothing +typeHasEq TFun{} = Nothing diff --git a/Examples.hs b/Examples.hs new file mode 100644 index 0000000..9d9cda7 --- /dev/null +++ b/Examples.hs @@ -0,0 +1,28 @@ +module Examples where + +import AST +import qualified Language as L + + +sumSq :: Exp env (Array (Int, ()) Double -> Double) +sumSq = Lam (TArray (STC STZ) TDouble) + (L.sum (App mapSq (Var (TArray (STC STZ) TDouble) Zero))) + +mapSq :: Exp env (Array (Int, ()) Double -> Array (Int, ()) Double) +mapSq = + Lam (TArray (STC STZ) TDouble) + (L.map (Lam TDouble + (App (Const CMulF) + (Pair (Var TDouble Zero) (Var TDouble Zero)))) + (Var (TArray (STC STZ) TDouble) Zero)) + +mapSqIota :: Exp env (Array (Int, ()) Double) +mapSqIota = + L.map (Lam TDouble + (App (Const CMulF) + (Pair (Var TDouble Zero) (Var TDouble Zero)))) + (Build (STC STZ) + (Pair (Lit (LInt 5)) (Lit LNil)) + (Lam (TPair TInt TNil) + (App (Const CtoF) + (Fst (Var (TPair TInt TNil) Zero))))) diff --git a/Gradient.hs b/Gradient.hs new file mode 100644 index 0000000..57ee904 --- /dev/null +++ b/Gradient.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE GADTs #-} +module Gradient where + +import AD +import AST +import qualified Language as L +import Sink + + +gradient :: Exp env (Array sh Double -> Double) -> Exp env (Array sh Double -> Array sh Double) +gradient func = + let TFun tarr@(TArray sht _) _ = typeof func + idxt = shapeTypeType sht + func' = ad func + in Lam tarr + (Build sht + (Shape (Var tarr Zero)) + (Lam idxt + (Snd (App (sinkExp2 func') + (L.zip (Var tarr (Succ Zero)) + (L.oneHot sht + (Shape (Var tarr (Succ Zero))) + (Var idxt Zero))))))) diff --git a/Language.hs b/Language.hs new file mode 100644 index 0000000..e16cf7c --- /dev/null +++ b/Language.hs @@ -0,0 +1,52 @@ +{-# LANGUAGE GADTs #-} + +{-| This module is intended to be imported qualified, perhaps as @L@. -} +module Language where + +import AST +import Sink + + +map :: Exp env (a -> b) -> Exp env (Array sh a) -> Exp env (Array sh b) +map f e = + let ty@(TArray sht _) = typeof e + sht' = shapeTypeType sht + in Let e + (Build sht (Shape (Var ty Zero)) + (Lam sht' + (App (sinkExp2 f) + (Index (Var ty (Succ Zero)) + (Var sht' Zero))))) + +sum :: Exp env (Array (Int, ()) Double) -> Exp env Double +sum e = + let ty@(TArray sht _) = typeof e + in Let e + (Ifold sht + (Lam (TPair TDouble (TPair TInt TNil)) + (App (Const CAddF) (Pair + (Fst (Var (TPair TDouble (TPair TInt TNil)) Zero)) + (Index (Var ty (Succ Zero)) (Snd (Var (TPair TDouble (TPair TInt TNil)) Zero)))))) + (Lit (LDouble 0)) + (Shape (Var ty Zero))) + +-- | The two input arrays are assumed to be the same size. +zip :: Exp env (Array sh a) -> Exp env (Array sh b) -> Exp env (Array sh (a, b)) +zip a b = + let tarr@(TArray sht _) = typeof a + idxt = shapeTypeType sht + in Let a + (Build sht + (Shape (Var tarr Zero)) + (Lam idxt + (Pair (Index (Var tarr (Succ Zero)) (Var idxt Zero)) + (Index (sinkExp2 b) (Var idxt Zero))))) + +oneHot :: ShapeType sh -> Exp env sh -> Exp env sh -> Exp env (Array sh Double) +oneHot sht sh idx = + let idxt = shapeTypeType sht + in Build sht sh + (Lam idxt + (Cond (App (Const (CEq idxt)) (Pair (Var idxt Zero) (sinkExp1 idx))) + (Lit (LDouble 1)) + (Lit (LDouble 0)))) @@ -0,0 +1,5 @@ +module Main where + + +main :: IO () +main = return () diff --git a/README.md b/README.md new file mode 100644 index 0000000..2d71583 --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +# \tilde F + +Re-implementation of the language, AD transformation, and optimisation framework from: + Amir Shaikhha, Andrew Fitzgibbon, Dimitrios Vytiniotis, and Simon Peyton + Jones. 2019. Efficient Differentiable Programming in a Functional + Array-Processing Language. Proc. ACM Program. Lang. 3, ICFP, Article 97 + (August 2019), 30 pages. https://doi.org/10.1145/3341701 @@ -0,0 +1,11 @@ +{-# OPTIONS -Wno-unused-imports #-} +module Repl where + + +import AST +import AD +import Examples +import Gradient +import qualified Language as L +import Simplify +import Sink diff --git a/Simplify.hs b/Simplify.hs new file mode 100644 index 0000000..9ceaef9 --- /dev/null +++ b/Simplify.hs @@ -0,0 +1,203 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module Simplify ( + simplify, + simplifyFix, +) where + +import Data.Bifunctor +import Data.GADT.Compare +import qualified Data.Kind as Kind +import Data.List (find) +import Data.Type.Equality + +import AST +import Sink + + +data family Info (env :: [Kind.Type]) a +-- data instance Info env Int = InfoInt +-- data instance Info env Bool = InfoBool +-- data instance Info env Double = InfoDouble +data instance Info env (Array sh t) = InfoArray (Exp env sh) +-- data instance Info env () = InfoNil +data instance Info env (a, b) = InfoPair (Info env a) (Info env b) +-- data instance Info env (a -> b) = InfoFun + +data IEnv env where + ITop :: IEnv env + ICons :: Type a -> Maybe (Info (a ': env) a) -> IEnv env -> IEnv (a ': env) + +sinkInfo1 :: Type a -> Info env a -> Info (t ': env) a +sinkInfo1 TArray{} (InfoArray e) = InfoArray (sinkExp1 e) +sinkInfo1 (TPair t1 t2) (InfoPair a b) = InfoPair (sinkInfo1 t1 a) (sinkInfo1 t2 b) +sinkInfo1 _ _ = error "Unknown info in sinkInfo1" + +iprj :: IEnv env -> Idx env a -> Maybe (Type a, Info env a) +iprj ITop _ = Nothing +iprj (ICons t m _) Zero = (t,) <$> m +iprj (ICons _ _ env) (Succ i) = (\(t, m) -> (t, sinkInfo1 t m)) <$> iprj env i + +simplifyFix :: Exp env a -> Exp env a +simplifyFix e = + let maxTimes = 4 + es = take (maxTimes + 1) (iterate simplify e) + pairs = zip es (tail es) + in case find (\(a,b) -> case geq a b of Just Refl -> True ; _ -> False) pairs of + Just (e', _) -> e' + Nothing -> error "Simplification doesn't converge!" + +simplify :: Exp env a -> Exp env a +simplify = fst . simplify' ITop + +simplify' :: IEnv env -> Exp env a -> (Exp env a, Maybe (Info env a)) +simplify' env = \case + App a b -> (simplifyApp (fst (simplify' env a)) (fst (simplify' env b)), Nothing) + Lam t e -> (Lam t (fst (simplify' (ICons t Nothing env) e)), Nothing) + Var t i -> (Var t i, snd <$> iprj env i) + Let arg e -> + let (arg', info) = simplify' env arg + env' = ICons (typeof arg) (sinkInfo1 (typeof arg) <$> info) env + in (simplifyLet arg' (fst (simplify' env' e)), Nothing) + Lit l -> (Lit l, Nothing) + Cond a b c -> + (Cond (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing) + Const c -> (Const c, Nothing) + Pair a b -> + let (a', ia) = simplify' env a + (b', ib) = simplify' env b + in (Pair a' b', InfoPair <$> ia <*> ib) + Fst e -> bimap simplifyFst (fmap (\(InfoPair i _) -> i)) (simplify' env e) + Snd e -> bimap simplifySnd (fmap (\(InfoPair _ i) -> i)) (simplify' env e) + Build sht a b -> + let a' = fst (simplify' env a) + in (Build sht a' (fst (simplify' env b)), Just (InfoArray a')) + Ifold sht a b c -> (Ifold sht (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing) + Index a b -> (simplifyIndex (fst (simplify' env a)) (fst (simplify' env b)), Nothing) + Shape e -> + case simplify' env e of + (_, Just (InfoArray she)) -> (she, Nothing) + (e', _) -> (Shape e', Nothing) + +simplifyApp :: Exp env (a -> b) -> Exp env a -> Exp env b +simplifyApp (Const CAddI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a + b)) +simplifyApp (Const CSubI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a - b)) +simplifyApp (Const CMulI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a * b)) +simplifyApp (Const CDivI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a `div` b)) +simplifyApp (Const CAddF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a + b)) +simplifyApp (Const CSubF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a - b)) +simplifyApp (Const CMulF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a * b)) +simplifyApp (Const CDivF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a / b)) +simplifyApp (Const CLog) (Lit (LDouble a)) = Lit (LDouble (log a)) +simplifyApp (Const CExp) (Lit (LDouble a)) = Lit (LDouble (exp a)) +simplifyApp (Const CtoF) (Lit (LInt a)) = Lit (LDouble (fromIntegral a)) +simplifyApp (Const CRound) (Lit (LDouble a)) = Lit (LInt (round a)) +simplifyApp (Const CLtI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LBool (a < b)) +simplifyApp (Const CLtF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LBool (a < b)) +simplifyApp (Const (CEq _)) (Pair a b) + | Just Refl <- geq a b + = Lit (LBool True) +simplifyApp (Const CAnd) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a && b)) +simplifyApp (Const COr) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a || b)) +simplifyApp (Const CNot) (Lit (LBool a)) = Lit (LBool (not a)) + +simplifyApp (Lam _ e) arg + | isDuplicable arg || countOcc Zero e <= 1 + = simplify (subst arg e) +simplifyApp (Lam _ e) arg = simplifyLet arg e + +simplifyApp a b = App a b + +simplifyLet :: Exp env a -> Exp (a ': env) b -> Exp env b +simplifyLet arg e + | isDuplicable arg || countOcc Zero e <= 1 + = simplify (subst arg e) +simplifyLet (Pair a b) e = + simplifyLet a $ + simplifyLet (sinkExp1 b) $ + subst' (\t -> \case Zero -> Pair (Var (typeof a) (Succ Zero)) + (Var (typeof b) Zero) + Succ i -> Var t (Succ (Succ i))) + e +simplifyLet (Cond c a b) e + | isDuplicable a && isDuplicable b + = simplifyLet c $ + (subst' (\t -> \case Zero -> Cond (Var TBool Zero) (sinkExp1 a) (sinkExp1 b) + Succ i -> Var t (Succ i)) + e) +simplifyLet a b = Let (simplify a) (simplify b) + +simplifyFst :: Exp env (a, b) -> Exp env a +simplifyFst (Pair e _) = e +simplifyFst (Let a e) = simplifyLet a (simplifyFst e) +simplifyFst e = Fst e + +simplifySnd :: Exp env (a, b) -> Exp env b +simplifySnd (Pair _ e) = e +simplifySnd (Let a e) = simplifyLet a (simplifySnd e) +simplifySnd e = Snd e + +simplifyIndex :: Exp env (Array sh a) -> Exp env sh -> Exp env a +simplifyIndex (Build _ _ f) e = simplifyApp f e +simplifyIndex a e = Index a e + +isDuplicable :: Exp env a -> Bool +isDuplicable (Lam _ e) = isDuplicable e +isDuplicable (Var _ _) = True +isDuplicable (Let a e) = isDuplicable a && isDuplicable e +isDuplicable (Lit (LInt _)) = True +isDuplicable (Lit (LBool _)) = True +isDuplicable (Lit (LDouble _)) = True +isDuplicable (Lit (LShape _)) = True +isDuplicable (Lit LNil) = True +isDuplicable (Lit (LPair l1 l2)) = isDuplicable (Lit l1) && isDuplicable (Lit l2) +isDuplicable (Const _) = True +isDuplicable (Pair a b) = isDuplicable a && isDuplicable b +isDuplicable (Fst e) = isDuplicable e +isDuplicable (Snd e) = isDuplicable e +isDuplicable _ = False + +countOcc :: Idx env t -> Exp env a -> Int +countOcc i (App a b) = countOcc i a + countOcc i b +countOcc i (Lam _ e) = countOcc (Succ i) e +countOcc i (Var _ j) + | Just Refl <- geq i j = 1 + | otherwise = 0 +countOcc i (Let a b) = countOcc i a + countOcc (Succ i) b +countOcc _ (Lit _) = 0 +countOcc i (Cond a b c) = countOcc i a + countOcc i b + countOcc i c +countOcc _ (Const _) = 0 +countOcc i (Pair a b) = countOcc i a + countOcc i b +countOcc i (Fst e) = countOcc i e +countOcc i (Snd e) = countOcc i e +countOcc i (Build _ a b) = countOcc i a + countOcc i b +countOcc i (Ifold _ a b c) = countOcc i a + countOcc i b + countOcc i c +countOcc i (Index a b) = countOcc i a + countOcc i b +countOcc i (Shape e) = countOcc i e + +subst :: Exp env t -> Exp (t ': env) a -> Exp env a +subst arg e = subst' (\t -> \case Zero -> arg ; Succ i -> Var t i) e + +subst' :: (forall t. Type t -> Idx env t -> Exp env' t) -> Exp env a -> Exp env' a +subst' f (App a b) = App (subst' f a) (subst' f b) +subst' f (Lam t e) = + Lam t (subst' (\t' -> \case Zero -> Var t' Zero ; Succ i -> sinkExp1 (f t' i)) e) +subst' f (Var t i) = f t i +subst' f (Let a b) = + Let (subst' f a) + (subst' (\t -> \case Zero -> Var t Zero ; Succ i -> sinkExp1 (f t i)) b) +subst' _ (Lit l) = Lit l +subst' f (Cond a b c) = Cond (subst' f a) (subst' f b) (subst' f c) +subst' _ (Const c) = Const c +subst' f (Pair a b) = Pair (subst' f a) (subst' f b) +subst' f (Fst e) = Fst (subst' f e) +subst' f (Snd e) = Snd (subst' f e) +subst' f (Build sht a b) = Build sht (subst' f a) (subst' f b) +subst' f (Ifold sht a b c) = Ifold sht (subst' f a) (subst' f b) (subst' f c) +subst' f (Index a b) = Index (subst' f a) (subst' f b) +subst' f (Shape e) = Shape (subst' f e) @@ -0,0 +1,47 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeOperators #-} +module Sink where + +import AST + + +newtype env :> env' = Weaken { (>:>) :: forall t'. Idx env t' -> Idx env' t' } + +wId :: env :> env +wId = Weaken id + +wSucc :: env :> env' -> env :> (a ': env') +wSucc (Weaken f) = Weaken (Succ . f) + +wSink :: env :> env' -> (a ': env) :> (a ': env') +wSink w = Weaken (\case Zero -> Zero + Succ i -> Succ (w >:> i)) + +(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 +Weaken f .> Weaken g = Weaken (f . g) + +sinkExp :: env :> env' -> Exp env a -> Exp env' a +sinkExp w = \case + App e1 e2 -> App (sinkExp w e1) (sinkExp w e2) + Lam t e -> Lam t (sinkExp (wSink w) e) + Var t i -> Var t (w >:> i) + Let e1 e2 -> Let (sinkExp w e1) (sinkExp (wSink w) e2) + Lit l -> Lit l + Cond e1 e2 e3 -> Cond (sinkExp w e1) (sinkExp w e2) (sinkExp w e3) + Const c -> Const c + Pair e1 e2 -> Pair (sinkExp w e1) (sinkExp w e2) + Fst e -> Fst (sinkExp w e) + Snd e -> Snd (sinkExp w e) + Build sht e1 e2 -> Build sht (sinkExp w e1) (sinkExp w e2) + Ifold sht e1 e2 e3 -> Ifold sht (sinkExp w e1) (sinkExp w e2) (sinkExp w e3) + Index e1 e2 -> Index (sinkExp w e1) (sinkExp w e2) + Shape e -> Shape (sinkExp w e) + +sinkExp1 :: Exp env a -> Exp (t ': env) a +sinkExp1 = sinkExp (wSucc wId) + +sinkExp2 :: Exp env a -> Exp (t1 ': t2 ': env) a +sinkExp2 = sinkExp (wSucc (wSucc wId)) diff --git a/ftilde.cabal b/ftilde.cabal new file mode 100644 index 0000000..1432547 --- /dev/null +++ b/ftilde.cabal @@ -0,0 +1,27 @@ +cabal-version: 2.0 +name: ftilde +synopsis: F tilde (following https://doi.org/10.1145/3341701) +version: 0.1.0.0 +license: MIT +author: Tom Smeding +maintainer: tom@tomsmeding.com +build-type: Simple + +executable ftilde + main-is: Main.hs + other-modules: + AST + AD + Examples + Gradient + Language + Repl + Simplify + Sink + build-depends: + base >= 4.13 && < 4.15, + vector, + some + hs-source-dirs: . + default-language: Haskell2010 + ghc-options: -Wall -O2 -threaded |