﻿module Rewrite


/// Raised when a rule cannot be interpreted as a rewrite rule, e.g. a quotation
/// that is neither `forall (fun ... -> l ==> r)` nor `l ==> r`.
exception InvalidRewriteRule

/// A rewrite rule over terms of type 'term.
type RewriteRule<'term> =
    /// Whether the rule reports itself applicable to the term.
    /// Implementations differ in precision (some always answer true).
    abstract Applies: 'term -> bool
    /// Rewrites the term. `None` means the rule did not fire; the combinators in
    /// this module treat any `Some` as "fired", even when the term is unchanged.
    abstract Apply: 'term -> 'term option

/// A rewrite rule that also exposes its left-hand side (pattern) and
/// right-hand side (replacement).
type RewriteRule<'term,'left,'right> =
    inherit RewriteRule<'term>
    abstract Left: 'left
    abstract Right: 'right

/// An uninterpreted rule: a function from the (boxed) value supplied for the
/// rule's bound variables to its (left, right) pair.
type RawRewriteRule<'l,'r> = RawRewriteRule of (obj -> 'l * 'r)

/// `l ==> r` builds a rule with no bound variables; the argument is ignored.
let (==>) l r = RawRewriteRule(fun (_: obj) -> (l, r))

/// `forall (fun x -> ...)` binds a rule variable. When the raw rule is run, the
/// argument is unboxed to the variable's type for `f`, and the same (boxed)
/// argument is then forwarded to the inner rule.
let forall (f: 'a -> RawRewriteRule<'l,'r>) =
    RawRewriteRule (fun arg ->
        let (RawRewriteRule inner) = f (unbox arg)
        inner arg)


module internal helpers =
    /// Fold step: try one rule on the current term. Yields the rewritten term and
    /// `true` when the rule returns Some; otherwise the term and the incoming flag.
    let applyone (term, fired) (rw:#RewriteRule<_>) =
        match rw.Apply term with
        | None -> term, fired
        | Some rewritten -> rewritten, true

    /// One left-to-right sweep over `rs`; each rule sees the previous rule's output.
    /// The flag reports whether any rule returned Some during the sweep.
    let applyall x rs = rs |> Seq.fold applyone (x, false)

    /// Repeats sweeps until one completes with no rule firing, then returns the term.
    /// NOTE(review): never terminates if some rule returns Some on every sweep
    /// (e.g. a rule that maps a term to itself).
    let rec applyallrep x rs =
        let next, fired = applyall x rs
        if fired then applyallrep next rs else next
open helpers


/// Reflexive closure of one rule: always applicable; returns the wrapped rule's
/// result when it fires, otherwise the term itself (so Apply is always Some).
let Reflexive(r:#RewriteRule<_>) =
    {   new RewriteRule<_> with
        member __.Applies _ = true
        member __.Apply term = Some (defaultArg (r.Apply term) term)
    }

/// Combines rules into a single sweep: every rule in `rs` is tried once, in order,
/// each on the previous rule's output.
/// Applies: true when at least one rule reports that it applies.
/// Apply: Some result when any rule fired during the sweep, otherwise None.
let Transitive(rs:#seq<#RewriteRule<_>>) =
    {   new RewriteRule<_> with
        member __.Applies term = rs |> Seq.exists (fun rule -> rule.Applies term)
        member __.Apply term =
            let result, changed = applyall term rs
            if changed then Some result else None
    }

/// Rewrites to a fixed point: repeats sweeps over all rules in `rs` until a sweep
/// in which none fires. Always applicable, and Apply always returns Some.
/// NOTE(review): does not terminate if a rule returns Some on every sweep
/// (e.g. a `Reflexive` rule).
let ReflexiveTransitiveClosure rs =
    {   new RewriteRule<_> with
        member __.Applies _ = true
        member __.Apply term = Some (applyallrep term rs)
    }


module ExprRewrite =
    open System.Collections.Generic
    open Microsoft.FSharp.Quotations
    open Microsoft.FSharp.Quotations.ExprShape
    open Microsoft.FSharp.Quotations.DerivedPatterns

    /// Builds a rewrite rule from a quoted rule `<@ forall (fun a b ... -> l ==> r) @>`
    /// (or a bare `<@ l ==> r @>`). Variables bound by `forall` are pattern variables:
    /// each matches any sub-expression, and repeated occurrences must match equal
    /// sub-expressions. Every other node of `l` must match the term structurally.
    /// Left/Right expose the two sides cast to Expr<'l> / Expr<'r>.
    /// Raises InvalidRewriteRule when the quotation does not have that shape.
    let ExprRewrite(raw:Expr<RawRewriteRule<'l,'r>>) =
        // Peel nested `forall` lambdas (collecting their variables) down to `l ==> r`.
        let rec parse = function
            | SpecificCall <@ forall @> (_,_,[Lambdas(vs,e)]) ->
                let vs',l,r = parse e
                vs' @ List.concat vs,l,r
            | SpecificCall <@ (==>) @> (_,_,[l;r]) -> [],l,r
            | _ -> raise InvalidRewriteRule

        // Matches a pattern against a term. `varmap` maps each pattern variable to
        // Some bound expression, or None while it is still unbound.
        // Returns (matched, bindings); bindings are only meaningful when matched.
        let rec unify varmap = function
            | ShapeVar v, x when Map.containsKey v varmap ->
                (match Map.find v varmap with
                 | Some m -> (m = x), varmap            // already bound: must be the same expression
                 | None -> true, Map.add v (Some x) varmap)
            | (ShapeCombination(o,xs) as e), (ShapeCombination(o',xs') as e') when xs.Length = xs'.Length ->
                // Rebuilding the pattern's children under the term's operator reproduces
                // the pattern only when both nodes use the same operator.
                let rebuilt = try RebuildShapeCombination(o',xs) with _ -> e'
                List.fold2 (fun (ok,vs) x x' -> if ok then unify vs (x,x') else ok,vs) (e = rebuilt,varmap) xs xs'
            | ShapeLambda(v,x), ShapeLambda(v',x') when v = v' -> unify varmap (x,x')
            // A non-pattern variable only matches itself (was `v' = v'`, which matched any variable).
            | ShapeVar v, ShapeVar v' when v = v' -> true,varmap
            | _ -> false,varmap

        // Replaces bound pattern variables in an expression by their bindings.
        let rec substitute varmap = function
            | ShapeVar v as e ->
                (match Map.tryFind v varmap with
                 | Some (Some m) -> m
                 | _ -> e)
            | ShapeLambda(v,x) -> Expr.Lambda(v, substitute varmap x)
            | ShapeCombination(o,xs) -> RebuildShapeCombination(o, List.map (substitute varmap) xs)

        let vars,left,right = parse raw
        // All pattern variables start unbound; Map is immutable, so this is shared safely.
        let unbound = vars |> List.map (fun v -> v, None) |> Map.ofList
        let tryMatch x = unify unbound (left,x)

        {   new RewriteRule<_,_,_> with
            member r.Left = Expr.Cast<'l> left
            member r.Right = Expr.Cast<'r> right

            member r.Applies x = fst (tryMatch x)
            member r.Apply x =
                match tryMatch x with
                | true, varmap -> Some(substitute varmap right)
                | _ -> None
        }

    /// Combines quoted rules into one system over Expr<'t>. Apply runs the rules to a
    /// fixed point at each sub-expression handed to Quote.Transform. Applies reports
    /// whether any single rule applies to such a sub-expression.
    /// Malformed rules raise InvalidRewriteRule when the rules are enumerated: the
    /// sequence is mapped lazily and re-enumerated on each use.
    /// NOTE(review): assumes Quote.Transform invokes its callbacks on every sub-expression.
    let MakeRewriteSystem xs =
        let makerule r =
            try r |> Expr.Cast |> ExprRewrite
            with _ -> raise InvalidRewriteRule
        let rules = xs |> Seq.map makerule
        let closure = ReflexiveTransitiveClosure rules
        {   new RewriteRule<Expr<'t>> with
            member r.Applies t =
                // The closure's Apply always returns Some, so it cannot reveal whether
                // anything fired; query the individual rules instead.
                let current = List.ofSeq rules
                let found = ref false
                let probe e =
                    if not !found && current |> List.exists (fun rule -> rule.Applies e) then
                        found := true
                    e
                Quote.Transform probe probe t |> ignore
                !found
            member r.Apply t =
                let apply e = match closure.Apply e with Some e' -> e' | None -> e
                Quote.Transform apply apply t |> Expr.Cast |> Some
        }


module PolynomialRewrite =
    open Value
    open Codegen
    open Algebra
    open Polynomial

    /// Rule rewriting monomial `m` to polynomial `p` within polynomials of algebra `A`.
    /// When gcd(LM t, m) = m (presumably: m divides the leading monomial), the leading
    /// term is replaced by τ m (LM t), scaled by LC t and multiplied by `p`, and the
    /// remaining terms are added back. Otherwise, unless `s` is set or the degree
    /// (via GetV) is below 1, the remaining terms are rewritten recursively and the
    /// leading term is kept.
    /// NOTE(review): Applies only inspects the leading monomial, so with `s = false`
    /// Apply may return Some where Applies returned false.
    let PolyRewrite(A:PolynomialAlgebra<_,_,_,_>,s,m,p) =
        let singleton = UnaryOp.Gen (fun a -> seq [a]) (fun a -> <@ seq[%a] @>)
        let rec rewrite t =
            let lead = A.LM t
            let rest = A.RT t
            if A.TM.MM.gcd lead m = m then
                A.TM.MM.τ m lead |> A.fromTerm (A.LC t) |> A.mul p |> A.add rest |> Some
            elif s || (A.deg t |> GetV) < 1 then
                None
            else
                rewrite rest
                |> Option.map (fun tail -> A.LT t |> singleton |> A.fromTerms |> A.add tail)

        {   new RewriteRule<_,_,_> with
            member __.Left = m
            member __.Right = p
            member __.Applies t = A.TM.MM.gcd (A.LM t) m = m
            member __.Apply t = rewrite t
        }

    /// Builds a polynomial rule from a raw `m ==> p` rule. The raw function is called
    /// with null. NOTE(review): a `forall` over a value type would fail to unbox null.
    let PolyRewriteRW (A,s) rw =
        let (RawRewriteRule f) = rw
        let m, p = f null
        PolyRewrite(A,s,m,p)

    /// Orients polynomial `p` as the rule LM p -> RT p scaled by -(LC p)^-1.
    /// The argument order below also drives type inference of `A`, so keep it.
    let PolyRewriteP (A,s) p =
        PolyRewrite(A, s, A.LM p, A.RT p |> A.scalar (A.LC p |> Field.Inverse A.CR |> A.CR.neg))

    /// Turns a sequence of polynomials into a fixed-point rewrite system.
    /// The sequence is mapped lazily and re-enumerated on every sweep.
    let MakeRewriteSystem(A,s) =
        Seq.map (PolyRewriteP(A,s)) >> ReflexiveTransitiveClosure
