From 0fefe4822c65bde81ec4c0da1b5b32a9b411ca79 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Thu, 24 Jun 2021 23:14:54 +0200 Subject: Initial --- .gitignore | 1 + AD.hs | 152 +++++++++++++++++++++++++++++++ AST.hs | 288 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ Examples.hs | 28 ++++++ Gradient.hs | 23 +++++ Language.hs | 52 +++++++++++ Main.hs | 5 ++ README.md | 7 ++ Repl.hs | 11 +++ Simplify.hs | 203 +++++++++++++++++++++++++++++++++++++++++ Sink.hs | 47 ++++++++++ ftilde.cabal | 27 ++++++ 12 files changed, 844 insertions(+) create mode 100644 .gitignore create mode 100644 AD.hs create mode 100644 AST.hs create mode 100644 Examples.hs create mode 100644 Gradient.hs create mode 100644 Language.hs create mode 100644 Main.hs create mode 100644 README.md create mode 100644 Repl.hs create mode 100644 Simplify.hs create mode 100644 Sink.hs create mode 100644 ftilde.cabal diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c33954f --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +dist-newstyle/ diff --git a/AD.hs b/AD.hs new file mode 100644 index 0000000..76fefe4 --- /dev/null +++ b/AD.hs @@ -0,0 +1,152 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module AD ( + Dual, + dual, + ad, +) where + +import Data.Bifunctor +import Data.Type.Equality + +import AST +import qualified Language as L +import Sink + + +-- | Dual-number version of a type. Pairs up every Double with a tangent copy. +type family Dual a where + Dual () = () + Dual Int = Int + Dual Bool = Bool + Dual Double = (Double, Double) + Dual (a, b) = (Dual a, Dual b) + Dual (Array sh a) = Array sh (Dual a) + Dual (a -> b) = Dual a -> Dual b + +data DEnv env env' where + ETop :: DEnv env env + ECons :: Type a -> DEnv env env' -> DEnv (a ': env) (Dual a ': env') + +-- | Convert a type to its 'Dual' variant. +dual :: Type a -> Type (Dual a) +dual (TFun a b) = TFun (dual a) (dual b) +dual TInt = TInt +dual TBool = TBool +dual TDouble = TPair TDouble TDouble +dual (TArray sht t) = TArray sht (dual t) +dual TNil = TNil +dual (TPair a b) = TPair (dual a) (dual b) + +-- | Forward AD. +-- +-- Since @'Dual' (a -> b) = 'Dual' a -> 'Dual' b@, calling 'ad' on a +-- function-typed expression gives the most useful results. +ad :: Exp env a -> Exp env (Dual a) +ad = ad' ETop + +ad' :: DEnv env env' -> Exp env a -> Exp env' (Dual a) +ad' env = \case + App e1 e2 -> App (ad' env e1) (ad' env e2) + Lam t e -> Lam (dual t) (ad' (ECons t env) e) + Var t i -> + case convIdx env i of + Left freei -> App (duale t) (Var t freei) + Right duali -> Var (dual t) duali + Let e1 e2 -> Let (ad' env e1) (ad' (ECons (typeof e1) env) e2) + Lit l -> App (duale (literalType l)) (Lit l) + Cond e1 e2 e3 -> Cond (ad' env e1) (ad' env e2) (ad' env e3) + Const CAddI -> Const CAddI + Const CSubI -> Const CSubI + Const CMulI -> Const CMulI + Const CDivI -> Const CDivI + Const CAddF -> + let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero + in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) + (Pair (App (Const CAddF) (Pair (Fst (Fst v)) (Fst (Snd v)))) + (App (Const CAddF) (Pair (Snd (Fst v)) (Snd (Snd v))))) + Const CSubF -> + let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero + in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) + (Pair (App (Const CSubF) (Pair (Fst (Fst v)) (Fst (Snd v)))) + (App (Const CSubF) (Pair (Snd (Fst v)) (Snd (Snd v))))) + Const CMulF -> + let v = Var (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) Zero + in Lam (TPair (TPair TDouble TDouble) (TPair TDouble TDouble)) + (Pair (App (Const CMulF) (Pair (Fst (Fst v)) (Fst (Snd v)))) + (App (Const CAddF) (Pair + (App (Const CMulF) (Pair (Fst (Fst v)) (Snd (Snd v)))) + (App (Const CMulF) (Pair (Fst (Snd v)) (Snd (Fst v))))))) + Const _ -> undefined + Pair e1 e2 -> Pair (ad' env e1) (ad' env e2) + Fst e -> Fst (ad' env e) + Snd e -> Snd (ad' env e) + Build sht e1 e2 + | Refl <- prfDualSht sht + -> Build sht (ad' env e1) (ad' env e2) + Ifold sht e1 e2 e3 + | Refl <- prfDualSht sht + -> Ifold sht (ad' env e1) (ad' env e2) (ad' env e3) + Index e1 e2 + | TArray sht _ <- typeof e1 + , Refl <- prfDualSht sht + -> Index (ad' env e1) (ad' env e2) + Shape e + | TArray sht _ <- typeof e + , Refl <- prfDualSht sht + -> Shape (ad' env e) + +convIdx :: DEnv env env' -> Idx env a -> Either (Idx env' a) (Idx env' (Dual a)) +convIdx ETop i = Left i +convIdx (ECons _ _) Zero = Right Zero +convIdx (ECons _ env) (Succ i) = bimap Succ Succ (convIdx env i) + +duale :: Type a -> Exp env (a -> Dual a) +duale topty = Lam topty (go topty) + where + go :: forall a env. Type a -> Exp (a ': env) (Dual a) + go ty = case ty of + TInt -> ref + TBool -> ref + TDouble -> Pair ref (Lit (LDouble 0)) + TArray _ t -> L.map (Lam t (go t)) ref + TNil -> Lit LNil + TPair t1 t2 -> Pair (Let (Fst ref) (go t1)) (Let (Snd ref) (go t2)) + TFun t1 t2 -> Lam (dual t1) + (App (duale t2) + (App (sinkExp1 ref) + (App (unduale t1) + (Var (dual t1) Zero)))) + where + ref :: Exp (a ': env) a + ref = Var ty Zero + +unduale :: Type a -> Exp env (Dual a -> a) +unduale topty = Lam (dual topty) (go topty) + where + go :: forall a env. Type a -> Exp (Dual a ': env) a + go ty = case ty of + TInt -> ref + TBool -> ref + TDouble -> Fst ref + TArray _ t -> L.map (Lam (dual t) (go t)) ref + TNil -> Lit LNil + TPair t1 t2 -> Pair (Let (Fst ref) (go t1)) (Let (Snd ref) (go t2)) + TFun t1 t2 -> Lam t1 + (App (unduale t2) + (App (sinkExp1 ref) + (App (duale t1) + (Var t1 Zero)))) + where + ref :: Exp (Dual a ': env) (Dual a) + ref = Var (dual ty) Zero + +prfDualSht :: ShapeType sh -> sh :~: Dual sh +prfDualSht STZ = Refl +prfDualSht (STC sht) + | Refl <- prfDualSht sht + = Refl diff --git a/AST.hs b/AST.hs new file mode 100644 index 0000000..7e9c69c --- /dev/null +++ b/AST.hs @@ -0,0 +1,288 @@ +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module AST where + +import Data.GADT.Compare +import Data.Type.Equality +import qualified Data.Vector as V +import Data.Vector (Vector) + + +data Exp env a where + App :: Exp env (a -> b) -> Exp env a -> Exp env b + Lam :: Type t -> Exp (t ': env) a -> Exp env (t -> a) + Var :: Type a -> Idx env a -> Exp env a + Let :: Exp env t -> Exp (t ': env) a -> Exp env a + Lit :: Literal a -> Exp env a + Cond :: Exp env Bool -> Exp env a -> Exp env a -> Exp env a + Const :: Constant a -> Exp env a + Pair :: Exp env a -> Exp env b -> Exp env (a, b) + Fst :: Exp env (a, b) -> Exp env a + Snd :: Exp env (a, b) -> Exp env b + Build :: ShapeType sh -> Exp env sh -> Exp env (sh -> a) -> Exp env (Array sh a) + Ifold :: ShapeType sh -> Exp env ((s, sh) -> s) -> Exp env s -> Exp env sh -> Exp env s + Index :: Exp env (Array sh a) -> Exp env sh -> Exp env a + Shape :: Exp env (Array sh a) -> Exp env sh + +data Constant a where + CAddI :: Constant ((Int, Int) -> Int) + CSubI :: Constant ((Int, Int) -> Int) + CMulI :: Constant ((Int, Int) -> Int) + CDivI :: Constant ((Int, Int) -> Int) + CAddF :: Constant ((Double, Double) -> Double) + CSubF :: Constant ((Double, Double) -> Double) + CMulF :: Constant ((Double, Double) -> Double) + CDivF :: Constant ((Double, Double) -> Double) + CLog :: Constant (Double -> Double) + CExp :: Constant (Double -> Double) + CtoF :: Constant (Int -> Double) + CRound :: Constant (Double -> Int) + + CLtI :: Constant ((Int, Int) -> Bool) + CLtF :: Constant ((Double, Double) -> Bool) + CEq :: Type a -> Constant ((a, a) -> Bool) + CAnd :: Constant ((Bool, Bool) -> Bool) + COr :: Constant ((Bool, Bool) -> Bool) + CNot :: Constant (Bool -> Bool) + +data Type a where + TInt :: Type Int + TBool :: Type Bool + TDouble :: Type Double + TArray :: ShapeType sh -> Type a -> Type (Array sh a) + TNil :: Type () + TPair :: Type a -> Type b -> Type (a, b) + TFun :: Type a -> Type b -> Type (a -> b) + +data Idx env a where + Zero :: Idx (a ': env) a + Succ :: Idx env a -> Idx (t ': env) a + +data Literal a where + LInt :: Int -> Literal Int + LBool :: Bool -> Literal Bool + LDouble :: Double -> Literal Double + LArray :: Array sh a -> Literal (Array sh a) + LShape :: Shape sh -> Literal sh + LNil :: Literal () + LPair :: Literal a -> Literal b -> Literal (a, b) + +data Shape sh where + Z :: Shape () + (:.) :: Int -> Shape sh -> Shape (Int, sh) + +data ShapeType sh where + STZ :: ShapeType () + STC :: ShapeType sh -> ShapeType (Int, sh) + +data Array sh a where + Array :: Shape sh -> Type a -> Vector a -> Array sh a + +deriving instance Show (Exp env a) +deriving instance Show (Constant a) +deriving instance Show (Type a) +deriving instance Show (Idx env a) +deriving instance Show (Literal a) +deriving instance Show (Shape a) +deriving instance Show (ShapeType a) + +instance Show (Array sh a) where + showsPrec p (Array sh t v) = + showParen (p > 10) $ + showString "Array " + . showsPrec 11 sh + . showsPrec 11 t + . (case typeHasShow t of + Just Has -> showsPrec 11 v + Nothing -> showString ("[_ * " ++ show (V.length v) ++ "]")) + +deriving instance Eq (Type a) +deriving instance Eq (Shape sh) +deriving instance Eq (ShapeType sh) +deriving instance Eq a => Eq (Array sh a) + +instance GEq (Exp env) where + geq (App a b) (App a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq App{} _ = Nothing + geq (Lam t e) (Lam t' e') | Just Refl <- geq t t', Just Refl <- geq e e' = Just Refl + geq Lam{} _ = Nothing + geq (Var t i) (Var t' i') | Just Refl <- geq t t', Just Refl <- geq i i' = Just Refl + geq Var{} _ = Nothing + geq (Let a b) (Let a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Let{} _ = Nothing + geq (Lit l) (Lit l') | Just Refl <- geq l l' = Just Refl + geq Lit{} _ = Nothing + geq (Cond a b c) (Cond a' b' c') | Just Refl <- geq a a', Just Refl <- geq b b', Just Refl <- geq c c' = Just Refl + geq Cond{} _ = Nothing + geq (Const c) (Const c') | Just Refl <- geq c c' = Just Refl + geq Const{} _ = Nothing + geq (Pair a b) (Pair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Pair{} _ = Nothing + geq (Fst a) (Fst a') | Just Refl <- geq a a' = Just Refl + geq Fst{} _ = Nothing + geq (Snd a) (Snd a') | Just Refl <- geq a a' = Just Refl + geq Snd{} _ = Nothing + geq (Build t a b) (Build t' a' b') | Just Refl <- geq t t', Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Build{} _ = Nothing + geq (Ifold t a b c) (Ifold t' a' b' c') | Just Refl <- geq t t', Just Refl <- geq a a', Just Refl <- geq b b' , Just Refl <- geq c c' + = Just Refl + geq Ifold{} _ = Nothing + geq (Index a b) (Index a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl + geq Index{} _ = Nothing + geq (Shape a) (Shape a') | Just Refl <- geq a a' = Just Refl + geq Shape{} _ = Nothing + +instance GEq Constant where + geq CAddI CAddI = Just Refl ; geq CAddI _ = Nothing + geq CSubI CSubI = Just Refl ; geq CSubI _ = Nothing + geq CMulI CMulI = Just Refl ; geq CMulI _ = Nothing + geq CDivI CDivI = Just Refl ; geq CDivI _ = Nothing + geq CAddF CAddF = Just Refl ; geq CAddF _ = Nothing + geq CSubF CSubF = Just Refl ; geq CSubF _ = Nothing + geq CMulF CMulF = Just Refl ; geq CMulF _ = Nothing + geq CDivF CDivF = Just Refl ; geq CDivF _ = Nothing + geq CLog CLog = Just Refl ; geq CLog _ = Nothing + geq CExp CExp = Just Refl ; geq CExp _ = Nothing + geq CtoF CtoF = Just Refl ; geq CtoF _ = Nothing + geq CRound CRound = Just Refl ; geq CRound _ = Nothing + geq CLtI CLtI = Just Refl ; geq CLtI _ = Nothing + geq CLtF CLtF = Just Refl ; geq CLtF _ = Nothing + geq (CEq t) (CEq t') | Just Refl <- geq t t' = Just Refl ; geq CEq{} _ = Nothing + geq CAnd CAnd = Just Refl ; geq CAnd _ = Nothing + geq COr COr = Just Refl ; geq COr _ = Nothing + geq CNot CNot = Just Refl ; geq CNot _ = Nothing + +instance GEq Type where + geq TInt TInt = Just Refl ; geq TInt _ = Nothing + geq TBool TBool = Just Refl ; geq TBool _ = Nothing + geq TDouble TDouble = Just Refl ; geq TDouble _ = Nothing + geq (TArray sht t) (TArray sht' t') | Just Refl <- geq sht sht', Just Refl <- geq t t' = Just Refl ; geq TArray{} _ = Nothing + geq TNil TNil = Just Refl ; geq TNil _ = Nothing + geq (TPair a b) (TPair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq TPair{} _ = Nothing + geq (TFun a b) (TFun a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq TFun{} _ = Nothing + +instance GEq (Idx env) where + geq Zero Zero = Just Refl + geq (Succ i) (Succ i') | Just Refl <- geq i i' = Just Refl + geq _ _ = Nothing + +instance GEq Literal where + geq (LInt a) (LInt a') | a == a' = Just Refl ; geq LInt{} _ = Nothing + geq (LBool a) (LBool a') | a == a' = Just Refl ; geq LBool{} _ = Nothing + geq (LDouble a) (LDouble a') | a == a' = Just Refl ; geq LDouble{} _ = Nothing + geq (LArray (Array sht t v)) (LArray (Array sht' t' v')) + | Just Refl <- geq sht sht' + , Just Refl <- geq t t' + = case typeHasEq t of + Just Has | v == v' -> Just Refl + | otherwise -> Nothing + Nothing -> error "GEq Literal: Literal array of incomparable values" + geq LArray{} _ = Nothing + geq (LShape a) (LShape a') | Just Refl <- geq a a' = Just Refl ; geq LShape{} _ = Nothing + geq LNil LNil = Just Refl ; geq LNil _ = Nothing + geq (LPair a b) (LPair a' b') | Just Refl <- geq a a', Just Refl <- geq b b' = Just Refl ; geq LPair{} _ = Nothing + +instance GEq Shape where + geq Z Z = Just Refl + geq (n :. sh) (n' :. sh') | n == n', Just Refl <- geq sh sh' = Just Refl + geq _ _ = Nothing + +instance GEq ShapeType where + geq STZ STZ = Just Refl + geq (STC sht) (STC sht') | Just Refl <- geq sht sht' = Just Refl + geq _ _ = Nothing + +shapeType :: Shape sh -> ShapeType sh +shapeType Z = STZ +shapeType (_ :. sh) = STC (shapeType sh) + +shapeType' :: Shape sh -> Type sh +shapeType' Z = TNil +shapeType' (_ :. sh) = TPair TInt (shapeType' sh) + +shapeTypeType :: ShapeType sh -> Type sh +shapeTypeType STZ = TNil +shapeTypeType (STC sht) = TPair TInt (shapeTypeType sht) + +literalType :: Literal a -> Type a +literalType LInt{} = TInt +literalType LBool{} = TBool +literalType LDouble{} = TDouble +literalType (LArray (Array sh t _)) = TArray (shapeType sh) t +literalType (LShape sh) = shapeType' sh +literalType LNil{} = TNil +literalType (LPair a b) = TPair (literalType a) (literalType b) + +constType :: Constant a -> Type a +constType CAddI = TFun (TPair TInt TInt) TInt +constType CSubI = TFun (TPair TInt TInt) TInt +constType CMulI = TFun (TPair TInt TInt) TInt +constType CDivI = TFun (TPair TInt TInt) TInt +constType CAddF = TFun (TPair TDouble TDouble) TDouble +constType CSubF = TFun (TPair TDouble TDouble) TDouble +constType CMulF = TFun (TPair TDouble TDouble) TDouble +constType CDivF = TFun (TPair TDouble TDouble) TDouble +constType CLog = TFun TDouble TDouble +constType CExp = TFun TDouble TDouble +constType CtoF = TFun TInt TDouble +constType CRound = TFun TDouble TInt +constType CLtI = TFun (TPair TInt TInt) TBool +constType CLtF = TFun (TPair TDouble TDouble) TBool +constType (CEq t) = TFun (TPair t t) TBool +constType CAnd = TFun (TPair TBool TBool) TBool +constType COr = TFun (TPair TBool TBool) TBool +constType CNot = TFun TBool TBool + +typeof :: Exp env a -> Type a +typeof (App e _) = let TFun _ t = typeof e in t +typeof (Lam t e) = TFun t (typeof e) +typeof (Var t _) = t +typeof (Let _ e) = typeof e +typeof (Lit l) = literalType l +typeof (Cond _ e _) = typeof e +typeof (Const c) = constType c +typeof (Pair e1 e2) = TPair (typeof e1) (typeof e2) +typeof (Fst e) = let TPair t _ = typeof e in t +typeof (Snd e) = let TPair _ t = typeof e in t +typeof (Build sht _ e) = let TFun _ t = typeof e in TArray sht t +typeof (Ifold _ _ e _) = typeof e +typeof (Index e _) = let TArray _ t = typeof e in t +typeof (Shape e) = let TArray sht _ = typeof e in shapeTypeType sht + +data Has c a where + Has :: c a => Has c a + +typeHasShow :: Type a -> Maybe (Has Show a) +typeHasShow TInt = Just Has +typeHasShow TBool = Just Has +typeHasShow TDouble = Just Has +typeHasShow TArray{} = Just Has +typeHasShow TNil = Just Has +typeHasShow (TPair a b) + | Just Has <- typeHasShow a + , Just Has <- typeHasShow b + = Just Has + | otherwise + = Nothing +typeHasShow TFun{} = Nothing + +typeHasEq :: Type a -> Maybe (Has Eq a) +typeHasEq TInt = Just Has +typeHasEq TBool = Just Has +typeHasEq TDouble = Just Has +typeHasEq (TArray _ t) + | Just Has <- typeHasEq t + = Just Has + | otherwise + = Nothing +typeHasEq TNil = Just Has +typeHasEq (TPair a b) + | Just Has <- typeHasEq a + , Just Has <- typeHasEq b + = Just Has + | otherwise + = Nothing +typeHasEq TFun{} = Nothing diff --git a/Examples.hs b/Examples.hs new file mode 100644 index 0000000..9d9cda7 --- /dev/null +++ b/Examples.hs @@ -0,0 +1,28 @@ +module Examples where + +import AST +import qualified Language as L + + +sumSq :: Exp env (Array (Int, ()) Double -> Double) +sumSq = Lam (TArray (STC STZ) TDouble) + (L.sum (App mapSq (Var (TArray (STC STZ) TDouble) Zero))) + +mapSq :: Exp env (Array (Int, ()) Double -> Array (Int, ()) Double) +mapSq = + Lam (TArray (STC STZ) TDouble) + (L.map (Lam TDouble + (App (Const CMulF) + (Pair (Var TDouble Zero) (Var TDouble Zero)))) + (Var (TArray (STC STZ) TDouble) Zero)) + +mapSqIota :: Exp env (Array (Int, ()) Double) +mapSqIota = + L.map (Lam TDouble + (App (Const CMulF) + (Pair (Var TDouble Zero) (Var TDouble Zero)))) + (Build (STC STZ) + (Pair (Lit (LInt 5)) (Lit LNil)) + (Lam (TPair TInt TNil) + (App (Const CtoF) + (Fst (Var (TPair TInt TNil) Zero))))) diff --git a/Gradient.hs b/Gradient.hs new file mode 100644 index 0000000..57ee904 --- /dev/null +++ b/Gradient.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE GADTs #-} +module Gradient where + +import AD +import AST +import qualified Language as L +import Sink + + +gradient :: Exp env (Array sh Double -> Double) -> Exp env (Array sh Double -> Array sh Double) +gradient func = + let TFun tarr@(TArray sht _) _ = typeof func + idxt = shapeTypeType sht + func' = ad func + in Lam tarr + (Build sht + (Shape (Var tarr Zero)) + (Lam idxt + (Snd (App (sinkExp2 func') + (L.zip (Var tarr (Succ Zero)) + (L.oneHot sht + (Shape (Var tarr (Succ Zero))) + (Var idxt Zero))))))) diff --git a/Language.hs b/Language.hs new file mode 100644 index 0000000..e16cf7c --- /dev/null +++ b/Language.hs @@ -0,0 +1,52 @@ +{-# LANGUAGE GADTs #-} + +{-| This module is intended to be imported qualified, perhaps as @L@. -} +module Language where + +import AST +import Sink + + +map :: Exp env (a -> b) -> Exp env (Array sh a) -> Exp env (Array sh b) +map f e = + let ty@(TArray sht _) = typeof e + sht' = shapeTypeType sht + in Let e + (Build sht (Shape (Var ty Zero)) + (Lam sht' + (App (sinkExp2 f) + (Index (Var ty (Succ Zero)) + (Var sht' Zero))))) + +sum :: Exp env (Array (Int, ()) Double) -> Exp env Double +sum e = + let ty@(TArray sht _) = typeof e + in Let e + (Ifold sht + (Lam (TPair TDouble (TPair TInt TNil)) + (App (Const CAddF) (Pair + (Fst (Var (TPair TDouble (TPair TInt TNil)) Zero)) + (Index (Var ty (Succ Zero)) (Snd (Var (TPair TDouble (TPair TInt TNil)) Zero)))))) + (Lit (LDouble 0)) + (Shape (Var ty Zero))) + +-- | The two input arrays are assumed to be the same size. +zip :: Exp env (Array sh a) -> Exp env (Array sh b) -> Exp env (Array sh (a, b)) +zip a b = + let tarr@(TArray sht _) = typeof a + idxt = shapeTypeType sht + in Let a + (Build sht + (Shape (Var tarr Zero)) + (Lam idxt + (Pair (Index (Var tarr (Succ Zero)) (Var idxt Zero)) + (Index (sinkExp2 b) (Var idxt Zero))))) + +oneHot :: ShapeType sh -> Exp env sh -> Exp env sh -> Exp env (Array sh Double) +oneHot sht sh idx = + let idxt = shapeTypeType sht + in Build sht sh + (Lam idxt + (Cond (App (Const (CEq idxt)) (Pair (Var idxt Zero) (sinkExp1 idx))) + (Lit (LDouble 1)) + (Lit (LDouble 0)))) diff --git a/Main.hs b/Main.hs new file mode 100644 index 0000000..0848f95 --- /dev/null +++ b/Main.hs @@ -0,0 +1,5 @@ +module Main where + + +main :: IO () +main = return () diff --git a/README.md b/README.md new file mode 100644 index 0000000..2d71583 --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +# \tilde F + +Re-implementation of the language, AD transformation, and optimisation framework from: + Amir Shaikhha, Andrew Fitzgibbon, Dimitrios Vytiniotis, and Simon Peyton + Jones. 2019. Efficient Differentiable Programming in a Functional + Array-Processing Language. Proc. ACM Program. Lang. 3, ICFP, Article 97 + (August 2019), 30 pages. https://doi.org/10.1145/3341701 diff --git a/Repl.hs b/Repl.hs new file mode 100644 index 0000000..85fe7be --- /dev/null +++ b/Repl.hs @@ -0,0 +1,11 @@ +{-# OPTIONS -Wno-unused-imports #-} +module Repl where + + +import AST +import AD +import Examples +import Gradient +import qualified Language as L +import Simplify +import Sink diff --git a/Simplify.hs b/Simplify.hs new file mode 100644 index 0000000..9ceaef9 --- /dev/null +++ b/Simplify.hs @@ -0,0 +1,203 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module Simplify ( + simplify, + simplifyFix, +) where + +import Data.Bifunctor +import Data.GADT.Compare +import qualified Data.Kind as Kind +import Data.List (find) +import Data.Type.Equality + +import AST +import Sink + + +data family Info (env :: [Kind.Type]) a +-- data instance Info env Int = InfoInt +-- data instance Info env Bool = InfoBool +-- data instance Info env Double = InfoDouble +data instance Info env (Array sh t) = InfoArray (Exp env sh) +-- data instance Info env () = InfoNil +data instance Info env (a, b) = InfoPair (Info env a) (Info env b) +-- data instance Info env (a -> b) = InfoFun + +data IEnv env where + ITop :: IEnv env + ICons :: Type a -> Maybe (Info (a ': env) a) -> IEnv env -> IEnv (a ': env) + +sinkInfo1 :: Type a -> Info env a -> Info (t ': env) a +sinkInfo1 TArray{} (InfoArray e) = InfoArray (sinkExp1 e) +sinkInfo1 (TPair t1 t2) (InfoPair a b) = InfoPair (sinkInfo1 t1 a) (sinkInfo1 t2 b) +sinkInfo1 _ _ = error "Unknown info in sinkInfo1" + +iprj :: IEnv env -> Idx env a -> Maybe (Type a, Info env a) +iprj ITop _ = Nothing +iprj (ICons t m _) Zero = (t,) <$> m +iprj (ICons _ _ env) (Succ i) = (\(t, m) -> (t, sinkInfo1 t m)) <$> iprj env i + +simplifyFix :: Exp env a -> Exp env a +simplifyFix e = + let maxTimes = 4 + es = take (maxTimes + 1) (iterate simplify e) + pairs = zip es (tail es) + in case find (\(a,b) -> case geq a b of Just Refl -> True ; _ -> False) pairs of + Just (e', _) -> e' + Nothing -> error "Simplification doesn't converge!" + +simplify :: Exp env a -> Exp env a +simplify = fst . simplify' ITop + +simplify' :: IEnv env -> Exp env a -> (Exp env a, Maybe (Info env a)) +simplify' env = \case + App a b -> (simplifyApp (fst (simplify' env a)) (fst (simplify' env b)), Nothing) + Lam t e -> (Lam t (fst (simplify' (ICons t Nothing env) e)), Nothing) + Var t i -> (Var t i, snd <$> iprj env i) + Let arg e -> + let (arg', info) = simplify' env arg + env' = ICons (typeof arg) (sinkInfo1 (typeof arg) <$> info) env + in (simplifyLet arg' (fst (simplify' env' e)), Nothing) + Lit l -> (Lit l, Nothing) + Cond a b c -> + (Cond (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing) + Const c -> (Const c, Nothing) + Pair a b -> + let (a', ia) = simplify' env a + (b', ib) = simplify' env b + in (Pair a' b', InfoPair <$> ia <*> ib) + Fst e -> bimap simplifyFst (fmap (\(InfoPair i _) -> i)) (simplify' env e) + Snd e -> bimap simplifySnd (fmap (\(InfoPair _ i) -> i)) (simplify' env e) + Build sht a b -> + let a' = fst (simplify' env a) + in (Build sht a' (fst (simplify' env b)), Just (InfoArray a')) + Ifold sht a b c -> (Ifold sht (fst (simplify' env a)) (fst (simplify' env b)) (fst (simplify' env c)), Nothing) + Index a b -> (simplifyIndex (fst (simplify' env a)) (fst (simplify' env b)), Nothing) + Shape e -> + case simplify' env e of + (_, Just (InfoArray she)) -> (she, Nothing) + (e', _) -> (Shape e', Nothing) + +simplifyApp :: Exp env (a -> b) -> Exp env a -> Exp env b +simplifyApp (Const CAddI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a + b)) +simplifyApp (Const CSubI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a - b)) +simplifyApp (Const CMulI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a * b)) +simplifyApp (Const CDivI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LInt (a `div` b)) +simplifyApp (Const CAddF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a + b)) +simplifyApp (Const CSubF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a - b)) +simplifyApp (Const CMulF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a * b)) +simplifyApp (Const CDivF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LDouble (a / b)) +simplifyApp (Const CLog) (Lit (LDouble a)) = Lit (LDouble (log a)) +simplifyApp (Const CExp) (Lit (LDouble a)) = Lit (LDouble (exp a)) +simplifyApp (Const CtoF) (Lit (LInt a)) = Lit (LDouble (fromIntegral a)) +simplifyApp (Const CRound) (Lit (LDouble a)) = Lit (LInt (round a)) +simplifyApp (Const CLtI) (Pair (Lit (LInt a)) (Lit (LInt b))) = Lit (LBool (a < b)) +simplifyApp (Const CLtF) (Pair (Lit (LDouble a)) (Lit (LDouble b))) = Lit (LBool (a < b)) +simplifyApp (Const (CEq _)) (Pair a b) + | Just Refl <- geq a b + = Lit (LBool True) +simplifyApp (Const CAnd) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a && b)) +simplifyApp (Const COr) (Pair (Lit (LBool a)) (Lit (LBool b))) = Lit (LBool (a || b)) +simplifyApp (Const CNot) (Lit (LBool a)) = Lit (LBool (not a)) + +simplifyApp (Lam _ e) arg + | isDuplicable arg || countOcc Zero e <= 1 + = simplify (subst arg e) +simplifyApp (Lam _ e) arg = simplifyLet arg e + +simplifyApp a b = App a b + +simplifyLet :: Exp env a -> Exp (a ': env) b -> Exp env b +simplifyLet arg e + | isDuplicable arg || countOcc Zero e <= 1 + = simplify (subst arg e) +simplifyLet (Pair a b) e = + simplifyLet a $ + simplifyLet (sinkExp1 b) $ + subst' (\t -> \case Zero -> Pair (Var (typeof a) (Succ Zero)) + (Var (typeof b) Zero) + Succ i -> Var t (Succ (Succ i))) + e +simplifyLet (Cond c a b) e + | isDuplicable a && isDuplicable b + = simplifyLet c $ + (subst' (\t -> \case Zero -> Cond (Var TBool Zero) (sinkExp1 a) (sinkExp1 b) + Succ i -> Var t (Succ i)) + e) +simplifyLet a b = Let (simplify a) (simplify b) + +simplifyFst :: Exp env (a, b) -> Exp env a +simplifyFst (Pair e _) = e +simplifyFst (Let a e) = simplifyLet a (simplifyFst e) +simplifyFst e = Fst e + +simplifySnd :: Exp env (a, b) -> Exp env b +simplifySnd (Pair _ e) = e +simplifySnd (Let a e) = simplifyLet a (simplifySnd e) +simplifySnd e = Snd e + +simplifyIndex :: Exp env (Array sh a) -> Exp env sh -> Exp env a +simplifyIndex (Build _ _ f) e = simplifyApp f e +simplifyIndex a e = Index a e + +isDuplicable :: Exp env a -> Bool +isDuplicable (Lam _ e) = isDuplicable e +isDuplicable (Var _ _) = True +isDuplicable (Let a e) = isDuplicable a && isDuplicable e +isDuplicable (Lit (LInt _)) = True +isDuplicable (Lit (LBool _)) = True +isDuplicable (Lit (LDouble _)) = True +isDuplicable (Lit (LShape _)) = True +isDuplicable (Lit LNil) = True +isDuplicable (Lit (LPair l1 l2)) = isDuplicable (Lit l1) && isDuplicable (Lit l2) +isDuplicable (Const _) = True +isDuplicable (Pair a b) = isDuplicable a && isDuplicable b +isDuplicable (Fst e) = isDuplicable e +isDuplicable (Snd e) = isDuplicable e +isDuplicable _ = False + +countOcc :: Idx env t -> Exp env a -> Int +countOcc i (App a b) = countOcc i a + countOcc i b +countOcc i (Lam _ e) = countOcc (Succ i) e +countOcc i (Var _ j) + | Just Refl <- geq i j = 1 + | otherwise = 0 +countOcc i (Let a b) = countOcc i a + countOcc (Succ i) b +countOcc _ (Lit _) = 0 +countOcc i (Cond a b c) = countOcc i a + countOcc i b + countOcc i c +countOcc _ (Const _) = 0 +countOcc i (Pair a b) = countOcc i a + countOcc i b +countOcc i (Fst e) = countOcc i e +countOcc i (Snd e) = countOcc i e +countOcc i (Build _ a b) = countOcc i a + countOcc i b +countOcc i (Ifold _ a b c) = countOcc i a + countOcc i b + countOcc i c +countOcc i (Index a b) = countOcc i a + countOcc i b +countOcc i (Shape e) = countOcc i e + +subst :: Exp env t -> Exp (t ': env) a -> Exp env a +subst arg e = subst' (\t -> \case Zero -> arg ; Succ i -> Var t i) e + +subst' :: (forall t. Type t -> Idx env t -> Exp env' t) -> Exp env a -> Exp env' a +subst' f (App a b) = App (subst' f a) (subst' f b) +subst' f (Lam t e) = + Lam t (subst' (\t' -> \case Zero -> Var t' Zero ; Succ i -> sinkExp1 (f t' i)) e) +subst' f (Var t i) = f t i +subst' f (Let a b) = + Let (subst' f a) + (subst' (\t -> \case Zero -> Var t Zero ; Succ i -> sinkExp1 (f t i)) b) +subst' _ (Lit l) = Lit l +subst' f (Cond a b c) = Cond (subst' f a) (subst' f b) (subst' f c) +subst' _ (Const c) = Const c +subst' f (Pair a b) = Pair (subst' f a) (subst' f b) +subst' f (Fst e) = Fst (subst' f e) +subst' f (Snd e) = Snd (subst' f e) +subst' f (Build sht a b) = Build sht (subst' f a) (subst' f b) +subst' f (Ifold sht a b c) = Ifold sht (subst' f a) (subst' f b) (subst' f c) +subst' f (Index a b) = Index (subst' f a) (subst' f b) +subst' f (Shape e) = Shape (subst' f e) diff --git a/Sink.hs b/Sink.hs new file mode 100644 index 0000000..1368cb6 --- /dev/null +++ b/Sink.hs @@ -0,0 +1,47 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeOperators #-} +module Sink where + +import AST + + +newtype env :> env' = Weaken { (>:>) :: forall t'. Idx env t' -> Idx env' t' } + +wId :: env :> env +wId = Weaken id + +wSucc :: env :> env' -> env :> (a ': env') +wSucc (Weaken f) = Weaken (Succ . f) + +wSink :: env :> env' -> (a ': env) :> (a ': env') +wSink w = Weaken (\case Zero -> Zero + Succ i -> Succ (w >:> i)) + +(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 +Weaken f .> Weaken g = Weaken (f . g) + +sinkExp :: env :> env' -> Exp env a -> Exp env' a +sinkExp w = \case + App e1 e2 -> App (sinkExp w e1) (sinkExp w e2) + Lam t e -> Lam t (sinkExp (wSink w) e) + Var t i -> Var t (w >:> i) + Let e1 e2 -> Let (sinkExp w e1) (sinkExp (wSink w) e2) + Lit l -> Lit l + Cond e1 e2 e3 -> Cond (sinkExp w e1) (sinkExp w e2) (sinkExp w e3) + Const c -> Const c + Pair e1 e2 -> Pair (sinkExp w e1) (sinkExp w e2) + Fst e -> Fst (sinkExp w e) + Snd e -> Snd (sinkExp w e) + Build sht e1 e2 -> Build sht (sinkExp w e1) (sinkExp w e2) + Ifold sht e1 e2 e3 -> Ifold sht (sinkExp w e1) (sinkExp w e2) (sinkExp w e3) + Index e1 e2 -> Index (sinkExp w e1) (sinkExp w e2) + Shape e -> Shape (sinkExp w e) + +sinkExp1 :: Exp env a -> Exp (t ': env) a +sinkExp1 = sinkExp (wSucc wId) + +sinkExp2 :: Exp env a -> Exp (t1 ': t2 ': env) a +sinkExp2 = sinkExp (wSucc (wSucc wId)) diff --git a/ftilde.cabal b/ftilde.cabal new file mode 100644 index 0000000..1432547 --- /dev/null +++ b/ftilde.cabal @@ -0,0 +1,27 @@ +cabal-version: 2.0 +name: ftilde +synopsis: F tilde (following https://doi.org/10.1145/3341701) +version: 0.1.0.0 +license: MIT +author: Tom Smeding +maintainer: tom@tomsmeding.com +build-type: Simple + +executable ftilde + main-is: Main.hs + other-modules: + AST + AD + Examples + Gradient + Language + Repl + Simplify + Sink + build-depends: + base >= 4.13 && < 4.15, + vector, + some + hs-source-dirs: . + default-language: Haskell2010 + ghc-options: -Wall -O2 -threaded -- cgit v1.2.3-54-g00ecf