{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

{-# LANGUAGE AllowAmbiguousTypes #-}

{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
module AST.Weaken.Auto (
  autoWeak,
  (&.), auto, auto1,
  Layout(..),
) where

import Data.Functor.Const
import GHC.OverloadedLabels
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)

import AST.Weaken
import Data
import Lemmas


-- | Type-level association-list lookup: the environment of the first entry
-- in @list@ whose name is @name@. If no entry has that name, the family
-- reduces to a custom 'TypeError' naming the missing segment.
type family Lookup name list where
  Lookup key '[] = TypeError (Text "The name '" :<>: Text key :<>: Text "' does not appear in the list.")
  Lookup key ('(key, found) : _) = found
  Lookup key (_ : rest) = Lookup key rest


-- | Description of an environment as a concatenation of named segments. Each
-- segment name is resolved against the type-level list @segments@ using
-- 'Lookup', and @env@ is the resulting concatenated environment.
--
-- The @withPre@ type parameter indicates whether there can be 'LPreW'
-- occurrences within this layout.
data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where
  -- | A single segment, referenced by name; its environment is the one
  -- 'Lookup' finds for that name in @segments@.
  LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments (Lookup name segments)
  -- | Pre-weaken with a weakening: this position holds segment @name1@,
  -- which is weakened to segment @name2@ with the given weakening before
  -- the layouts are compared. Only permitted when @withPre@ is 'True'.
  LPreW :: forall name1 name2 segments.
           SegmentName name1 -> SegmentName name2
        -> Lookup name1 segments :> Lookup name2 segments
        -> Layout True segments (Lookup name1 segments)
  -- | Concatenation of two layouts; the environments are 'Append'ed.
  (:++:) :: Layout withPre segments env1 -> Layout withPre segments env2 -> Layout withPre segments (Append env1 env2)
-- No precedence is given, so ':++:' is right-associative at precedence 9.
infixr :++:

-- | A label @#name@ denotes the single segment called @name@ ('LSeg').
-- The environment index is tied to @name@ by an equality constraint in the
-- context, because type family applications are not allowed in instance heads.
instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout withPre segments seg) where
  fromLabel = LSeg (symbolSing :: SSymbol name)

-- | A segment name wrapped so that it can be written as an overloaded label
-- (@#name@) in argument positions of 'LPreW'.
newtype SegmentName sym = SegmentName (SSymbol sym)
  deriving (Show)

-- | A label @#name@ used where a 'SegmentName' is expected yields that name.
-- The @name ~ name'@ constraint lets the instance match first and then fix
-- the type argument.
instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') where
  fromLabel = SegmentName (symbolSing @name)


-- | Runtime mirror of the type-level segment list @segments@. Each entry
-- carries the segment's name and the spine of its type list, with no
-- per-element payload (elements are @'Const' ()@).
data SSegments (segments :: [(Symbol, [t])]) where
  -- | No segments.
  SSegNil :: SSegments '[]
  -- | A named segment with its spine, followed by the remaining segments.
  SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)

-- | A label @#name@ applied to a spine (e.g. @#name auto@) builds a
-- one-element 'SSegments'. The spine's payload is discarded; only its shape
-- is kept.
instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where
  fromLabel spine = SSegCons (symbolSing :: SSymbol name) (slistMap (const (Const ())) spine) SSegNil

-- | Spine of a statically known type list, obtained from 'KnownListSpine'.
-- Typically passed to a segment label, as in @#name auto@.
auto :: KnownListSpine list => SList (Const ()) list
auto = knownListSpine

-- | Spine of a one-element type list, for segments holding a single type.
auto1 :: SList (Const ()) '[t]
auto1 = SCons (Const ()) SNil

infixr &.
-- | Concatenate two segment lists, as in @#a auto &. #b auto1@. The runtime
-- lists are joined in the same order as the type-level 'Append'.
(&.) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2)
SSegNil &. rest = rest
SSegCons name spine more &. rest = SSegCons name spine (more &. rest)


-- | Find the spine of the segment called @name@ in the runtime segment list.
--
-- The result is indexed by the type-level @'Lookup' name segments@, so the
-- caller gets the spine of exactly the environment that 'Lookup' computes.
--
-- Calls 'error' in two cases:
--
-- * the name does not occur in the list ("Segment not found"), or
-- * the name occurs more than once ("Duplicate segment with name").
--   The type-level 'Lookup' would silently pick the first occurrence, so
--   duplicates are rejected here instead.
segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments)
segmentLookup = \segs name ->
  case go segs name of
    Just ts -> ts
    Nothing -> error $ "Segment not found: " ++ fromSSymbol name
  where
    -- Returns 'Nothing' only when the name does not occur at all. Once the
    -- name is found, the rest of the list is still scanned to detect
    -- duplicates.
    go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs'))
    go SSegNil _ = Nothing
    -- The SSymbol patterns bring the KnownSymbol evidence that 'sameSymbol'
    -- needs into scope.
    go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol =
      case sameSymbol n name of
        Just Refl ->
          case go sseg name of
            Nothing -> Just ts
            Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name
        Nothing ->
          -- 'sameSymbol' has shown at runtime that n and name differ. GHC
          -- cannot use that disequality to rule out the first 'Lookup'
          -- equation, so the step to the recursive equation is asserted by
          -- hand. This is sound exactly because the two symbols are distinct.
          case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of
            Refl -> go sseg name

-- | A 'Layout' flattened into a right-nested list of segment references, as
-- produced by 'lineariseLayout'. As with 'Layout', @withPre@ says whether
-- pre-weakened segments ('LinAppPreW') may occur.
data LinLayout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where
  -- | The empty layout.
  LinEnd :: LinLayout withPre segments '[]
  -- | A named segment followed by the rest of the layout.
  LinApp :: SSymbol name -> LinLayout withPre segments env
         -> LinLayout withPre segments (Append (Lookup name segments) env)
  -- | Segment @name1@, to be weakened to segment @name2@ with the given
  -- weakening, followed by the rest of the layout.
  LinAppPreW :: SSymbol name1 -> SSymbol name2
             -> Lookup name1 segments :> Lookup name2 segments
             -> LinLayout True segments env
             -> LinLayout True segments (Append (Lookup name1 segments) env)

-- | Concatenate two linearised layouts. Each cons step re-associates the
-- type-level 'Append' using 'lemAppendAssoc'.
linLayoutAppend :: LinLayout withPre segments env1 -> LinLayout withPre segments env2 -> LinLayout withPre segments (Append env1 env2)
linLayoutAppend LinEnd rhs = rhs
linLayoutAppend (LinApp (seg :: SSymbol seg) (rest :: LinLayout _ segs restEnv)) (rhs :: LinLayout _ _ rhsEnv) =
  case lemAppendAssoc @(Lookup seg segs) @restEnv @rhsEnv of
    Refl -> LinApp seg (linLayoutAppend rest rhs)
linLayoutAppend (LinAppPreW (src :: SSymbol src) dst weakening (rest :: LinLayout _ segs restEnv)) (rhs :: LinLayout _ _ rhsEnv) =
  case lemAppendAssoc @(Lookup src segs) @restEnv @rhsEnv of
    Refl -> LinAppPreW src dst weakening (linLayoutAppend rest rhs)

-- | Flatten a 'Layout' tree (built with ':++:') into a right-nested
-- 'LinLayout'. Each leaf becomes a one-element list, which needs
-- 'lemAppendNil' to match the leaf's environment.
lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env
lineariseLayout layout = case layout of
  left :++: right ->
    linLayoutAppend (lineariseLayout left) (lineariseLayout right)
  (LSeg seg :: Layout _ _ segEnv) ->
    case lemAppendNil @segEnv of
      Refl -> LinApp seg LinEnd
  (LPreW (SegmentName src) (SegmentName dst) weakening :: Layout _ _ preEnv) ->
    case lemAppendNil @preEnv of
      Refl -> LinAppPreW src dst weakening LinEnd

-- | Eliminate all pre-weakenings from a linearised layout, in
-- continuation-passing style. The continuation receives two things:
--
-- * a weakening from the original environment to the new one, and
-- * the resulting layout without 'LinAppPreW', in which each pre-weakened
--   segment @name1@ has been replaced by its target segment @name2@.
--
-- Plain segments are copied through unchanged.
preWeaken :: SSegments segments -> LinLayout True segments env
          -> (forall env'. env :> env' -> LinLayout False segments env' -> r) -> r
preWeaken segs layout k = case layout of
  LinEnd -> k WId LinEnd
  LinApp seg rest ->
    preWeaken segs rest $ \wRest rest' ->
      k (wCopies (segmentLookup segs seg) wRest) (LinApp seg rest')
  LinAppPreW srcSeg dstSeg weakening rest ->
    preWeaken segs rest $ \wRest rest' ->
      let srcSpine = segmentLookup segs srcSeg
          dstSpine = segmentLookup segs dstSeg
      in k (WStack srcSpine dstSpine weakening wRest) (LinApp dstSeg rest')

-- | Search a linearised layout for its first occurrence of segment @name@.
--
-- * If the segment is absent, the not-found continuation is returned.
-- * Otherwise the main continuation is called with the layout minus that
--   occurrence, plus a weakening that moves the segment to the front.
pullDown :: SSegments segments -> SSymbol name -> LinLayout False segments env
         -> r  -- Name was not found in source
         -> (forall env'. LinLayout False segments env' -> env :> Append (Lookup name segments) env' -> r)
         -> r
pullDown segs name@SSymbol linlayout kNotFound k =
  case linlayout of
    LinEnd -> kNotFound
    LinApp here@SSymbol rest ->
      case sameSymbol name here of
        -- Found at the head: nothing needs to move.
        Just Refl -> k rest WId
        -- Found deeper: pull it up within the tail, then swap it past the head.
        Nothing ->
          pullDown segs name rest kNotFound $ \(rest' :: LinLayout _ _ restEnv) wRest ->
            let hereSpine = segmentLookup segs here
                swapHead = WSwap @restEnv hereSpine (segmentLookup segs name)
            in k (LinApp here rest') (swapHead .> wCopies hereSpine wRest)

-- | Build a weakening from the source layout's environment to the target
-- layout's environment by walking the target from left to right:
--
-- * a segment at the head of both layouts is copied;
-- * a target segment found later in the source is first pulled to the front
--   with 'pullDown';
-- * a target segment absent from the source is sunk, i.e. introduced fresh.
--
-- Calls 'error' if source segments remain once the target is exhausted.
sortLinLayouts :: forall segments env1 env2.
                  SSegments segments
               -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2
sortLinLayouts _ LinEnd LinEnd = WId
sortLinLayouts _ LinEnd LinApp{} = WClosed
sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target"
sortLinLayouts segs source@(LinApp srcName@SSymbol srcRest) (LinApp tgtName@SSymbol tgtRest) =
  case sameSymbol srcName tgtName of
    Just Refl -> wCopies (segmentLookup segs srcName) (sortLinLayouts segs srcRest tgtRest)
    Nothing ->
      let tgtSpine = segmentLookup segs tgtName
          -- Used only if tgtName does not occur in the source.
          absent = wSinks tgtSpine .> sortLinLayouts segs source tgtRest
      in pullDown segs tgtName source absent $ \srcRest' wPull ->
           -- tgtName now heads the rearranged source as well, so copy it and
           -- sort the remaining tails.
           wCopies tgtSpine (sortLinLayouts segs srcRest' tgtRest) .> wPull

-- | Automatically construct a weakening between two layouts over the same
-- segments. Pre-weakenings in the source layout are applied first, and the
-- resulting layout is then rearranged into the target layout.
autoWeak :: forall segments env1 env2.
            SSegments segments -> Layout True segments env1 -> Layout False segments env2 -> env1 :> env2
autoWeak segs source target =
  let targetLin = lineariseLayout target
  in preWeaken segs (lineariseLayout source) $ \wPre sourceLin ->
       sortLinLayouts segs sourceLin targetLin .> wPre