﻿module GBasis

open System
open System.Collections.Generic
open Value
open Codegen
open Algebra
open Polynomial
open Bool
open Tuple

/// Selects the post-processing GBSolver applies once the main Buchberger loop ends.
type BasisKind =
    /// Return the container contents as they are (PC.All); no post-processing.
    | StandardBasis
    /// Drop the polynomials the ReductionStrategy flags as redundant; survivors are not canonicalized.
    | MinimalBasis
    /// Drop redundant polynomials and pass every survivor through the CanonicalForm.
    | ReducedBasis

/// Verbosity-gated tracing hooks used by the solver to emit diagnostics.
/// Each message carries a level; implementations decide which levels to report.
type Trace =
    /// Highest level this trace reports (Debug.NoTrace uses 0, Debug.ConsoleTrace its cap).
    abstract Level: int
    /// Emit a value at the given level (Debug.ConsoleTrace maps this to Console.Write).
    abstract Write: int -> Value<'a> -> CodeGen<unit,'w>
    /// Emit a value at the given level followed by a line break (Console.WriteLine in ConsoleTrace).
    abstract WriteLine: int -> Value<'a> -> CodeGen<unit,'w>

/// Converts the solver's raw argument into the initial sequence of items (generators).
type Input<'itm,'inp> =
    abstract Process: Value<'inp> -> CodeGen<seq<'itm>,'w>

/// Storage for the polynomials of the basis under construction, addressed by indexes.
type Container<'itm,'idx,'cnt> =
    /// Create the container from the initial items.
    abstract Init: Value<seq<'itm>> -> CodeGen<'cnt,'w>
    /// Insert an item and return the index it is stored under
    /// (ListContainer: its position; SimpleContainer: the item itself).
    abstract Add: Value<'cnt> -> Value<'itm> -> CodeGen<'idx,'w>
    /// Look up the item stored under an index.
    abstract Get: Value<'cnt> -> Value<'idx> -> CodeGen<'itm,'w>
    /// All stored items.
    abstract All: Value<'cnt> -> CodeGen<seq<'itm>,'w>
    /// All indexes currently in use.
    abstract Indexes: Value<'cnt> -> CodeGen<seq<'idx>,'w>

/// Pending pairs of container indexes whose S-polynomials still have to be processed.
type WorkingSet<'idx,'cnt,'ws> =
    /// Create an empty working set.
    abstract Init: Value<'cnt> -> CodeGen<'ws,'w>
    /// Add the pair (i, j); the container is passed so implementations can inspect the polynomials.
    abstract Add: Value<'ws> -> Value<'cnt> * Value<'idx>*Value<'idx> -> CodeGen<unit,'w>
    /// Remove the pair (i, j).
    abstract Del: Value<'ws> -> Value<'cnt> * Value<'idx>*Value<'idx> -> CodeGen<unit,'w>
    /// Remove and return the next pair to process.
    abstract Pick: Value<'ws> -> Value<'cnt> -> CodeGen<'idx*'idx,'w>
    /// True while at least one pair is pending.
    abstract HasMore: Value<'ws> -> Value<'cnt> -> CodeGen<bool,'w>
    /// All pending pairs, without removing them.
    abstract All: Value<'ws> -> Value<'cnt> -> CodeGen<seq<'idx*'idx>,'w>

/// Keeps the working set in step with the container as new polynomials are added.
type ExpansionStrategy<'idx,'cnt,'ws,'exp> =
    /// Create the strategy state; implementations may also seed the working set with
    /// the initial pairs (DirectExpand and ExpandBuchbergerTriples do).
    abstract Init: Value<'cnt> * Value<'ws> -> CodeGen<'exp,'w>
    /// Called after a new polynomial was stored under the given index; updates the working set.
    abstract Expand: Value<'exp> -> Value<'cnt> * Value<'ws> * Value<'idx> -> CodeGen<unit,'w>

/// Construction of the S-polynomial for a pair of container indexes.
type SPoly<'poly,'idx,'cnt,'sp> =
    abstract Init: Value<'cnt> -> CodeGen<'sp,'w>
    /// S-polynomial of the polynomials stored under the two indexes.
    abstract σ: Value<'sp> -> Value<'cnt> * Value<'idx>*Value<'idx> -> CodeGen<'poly,'w>

/// Reduction of a polynomial with respect to the polynomials held in the container.
type NormalRemainder<'poly,'cnt,'nr> =
    abstract Init: Value<'cnt> -> CodeGen<'nr,'w>
    /// Remainder of the polynomial after reduction by the container's contents.
    abstract NR: Value<'nr> -> Value<'cnt> * Value<'poly> -> CodeGen<'poly,'w>

/// Redundancy test used during post-processing (MinimalBasis / ReducedBasis).
type ReductionStrategy<'idx,'cnt,'rs> =
    abstract Init: Value<'cnt> -> CodeGen<'rs,'w>
    /// True when the polynomial under the index is redundant; GBSolver drops it from the result.
    abstract Reduce: Value<'rs> -> Value<'cnt> * Value<'idx> -> CodeGen<bool,'w>

/// Rewrites each surviving polynomial into canonical form (used for ReducedBasis only).
type CanonicalForm<'poly,'cnt,'cf> =
    abstract Init: Value<'cnt> -> CodeGen<'cf,'w>
    abstract Canonicalize: Value<'cf> -> Value<'cnt> * Value<'poly> -> CodeGen<'poly,'w>

/// Converts the final sequence of basis polynomials into the solver's result.
type Output<'itm,'out> =
    abstract Process: Value<seq<'itm>> -> CodeGen<'out,'w>


/// Generative (staged) implementation of Buchberger's algorithm.
///
/// Every argument is a pluggable strategy:
///   DB - trace sink (levels 1-3 are used below)
///   BK - which basis to return (see BasisKind)
///   PA - polynomial algebra (its zero test discards S-polynomials that reduce to 0)
///   IP - converts the raw argument into the initial generators
///   PC - container holding the basis under construction
///   WS - working set of index pairs still to be processed
///   ES - seeds / updates WS as polynomials are added
///   SP - S-polynomial of a pair of indexes
///   NR - normal remainder of a polynomial w.r.t. the container
///   RS - redundancy test, used for MinimalBasis and ReducedBasis
///   CF - canonical form applied to the survivors, used for ReducedBasis only
///   OP - converts the final sequence of polynomials into the result
///
/// The body `gen` is wrapped with `Fun` and passed to `Generate`; the generated
/// result is returned without being evaluated.
let GBSolver (DB:Trace,
              BK:BasisKind,
              PA:PolynomialAlgebra<_,_,_,_>,
              IP:Input<_,_>,
              PC:Container<_,_,_>,
              WS:WorkingSet<_,_,_>,
              ES:ExpansionStrategy<_,_,_,_>,
              SP:SPoly<_,_,_,_>,
              NR:NormalRemainder<_,_,_>,
              RS:ReductionStrategy<_,_,_>,
              CF:CanonicalForm<_,_,_>,
              OP:Output<_,_>) =

    // Generative implementation of Buchberger's algorithm; `f` is the raw input handed to IP
    let gen f = codegen {
        // NOTE(review): presumably records the requested basis kind in the generator state
        do! StateRecord("Basis Kind").Extend BK

        // Initialize main modules
        let! _ = DB.WriteLine 1 <| V"Begin initialization"
        let! ip = IP.Process f
        let! pc = PC.Init ip
        let! ws = WS.Init pc
        let! es = ES.Init (pc,ws)
        let! sp = SP.Init pc
        let! nr = NR.Init pc
        let! _ = DB.WriteLine 1 <| V"Initialization complete"

        // Main loop: process pending pairs until the working set is empty
        while WS.HasMore ws pc do
            // Pick (and remove) the next pair of polynomial indexes
            use! p = WS.Pick ws pc
            let! _ = DB.Write 2 <| V"Picked pair: "
            let! _ = DB.WriteLine 2 p

            // Calculate the S-Polynomial
            use! σ = SP.σ sp (pc, Fst p, Snd p)
            let! _ = DB.Write 3 <| V" - Residual = "
            let! _ = DB.WriteLine 3 σ

            // Calculate the normal remainder of the s-poly
            use! nr_σ = NR.NR nr (pc,σ)
            let! _ = DB.Write 3 <| V" - Normalized = "
            let! _ = DB.WriteLine 3 nr_σ

            // Keep the remainder only if it did not reduce to 0
            yield! IfU (Not (PA.isZero nr_σ)) <| codegen {
                // Add new polynomial
                use! j = PC.Add pc nr_σ
                let! _ = DB.Write 2 <| V" * Adding new polynomial #"
                let! _ = DB.WriteLine 2 j

                // Update the working set with the pairs involving the newly added element
                yield! ES.Expand es (pc,ws,j)
            }

        let! _ = DB.WriteLine 1 <| V"Begin post-processing"
        let! res =
            // StandardBasis: return the container contents unchanged
            if BK = StandardBasis then PC.All pc
            else codegen {
                // Initialize modules for computing reduced/minimal Groebner bases
                let! rs = RS.Init pc
                let! cf = CF.Init pc
                let! idx = PC.Indexes pc
                let! _ = DB.WriteLine 1 <| V"Reductions initiated"

                // Drop polynomials the reduction strategy flags as redundant (minimal GB)
                let! idx' = Iterate Seq.Filter idx <| fun i -> codegen {
                    use! b = RS.Reduce rs (pc,i)
                    let! _ = DB.Write 2 <| V"Polynomial #"
                    let! _ = DB.Write 2 i
                    let! _ = DB.WriteLine 2 <| Control.If b (V" is redundant") (V" does not reduce")
                    yield Not b
                }

                yield! Iterate Seq.Map idx' <| fun i -> codegen {
                    let! c_i = PC.Get pc i
                    yield!
                        // MinimalBasis: keep the surviving polynomial as it is
                        if BK = MinimalBasis then Return c_i
                        else codegen {
                            // ReducedBasis: canonicalize each surviving polynomial (reduced GB)
                            let! p = CF.Canonicalize cf (pc,c_i)
                            let! _ = DB.Write 3 <| V" - Canonicalized = "
                            let! _ = DB.WriteLine 3 p
                            yield p
                        }
                }
            }
        // The basis output
        let! _ = DB.WriteLine 1 <| V"Returning result"
        return! OP.Process res
    }

    let res = gen |> Fun |> Generate
    // Evaluation is left disabled: the generated code is returned as-is
    //Quote.Eval res
    res


module Debug =
    /// A Trace that reports at level 0 and discards every message.
    let NoTrace<'a> =
        {   new Trace with
            member t.Level = 0
            member t.Write _ _ = Return Unit
            member t.WriteLine _ _ = Return Unit
        }

    /// Render a value for trace output (called at run time from generated code, see `print`).
    ///  - null references render as "null" (this includes `None` and `()`,
    ///    which F# represents as null);
    ///  - strings are returned unchanged;
    ///  - everything else is formatted structurally with `%A`.
    /// Formatting stays best-effort: if `%A` itself throws, the result names the
    /// runtime type and the error instead of masquerading as a null value
    /// (the previous catch-all returned "null" for any failure).
    let ToString e =
        match box e with
        | null -> "null"
        | :? string as s -> s
        | o ->
            try sprintf "%A" e
            with ex -> sprintf "<%s: %s>" (o.GetType().Name) ex.Message

    /// Build a call of the quoted printer `fn` on a Value: string-typed
    /// expressions are passed through (cast to a string expression), other
    /// expressions are converted with ToString inside the generated code, and
    /// already-known values (V) are converted immediately.
    let print fn = function
        | E e when e.Type = typeof<String> -> Function.Apply (E fn) (E(Microsoft.FSharp.Quotations.Expr.Cast e))
        | E e -> Function.Apply (E fn) (E<@ ToString %e @>)
        | V v -> Function.Apply (E fn) (V(ToString v))

    /// Console-backed Trace with verbosity cap `cap`. Messages whose level is
    /// greater than `cap` produce nothing; the others produce a Console.Write /
    /// Console.WriteLine call, wrapped with Return and passed to Prepend.
    let ConsoleTrace cap =
        {   new Trace with
            member t.Level = cap
            member t.Write lvl s =
                if lvl > cap then Return Unit
                else print <@ System.Console.Write @> s |> Return |> Prepend
            member t.WriteLine lvl s =
                if lvl > cap then Return Unit
                else print <@ System.Console.WriteLine @> s |> Return |> Prepend
        }

module Input =
    /// Input stage that ignores its argument and starts from no generators.
    let NoInput =
        {   new Input<_,_> with
            member this.Process _ =
                Return (Seq.Empty())
        }

    /// Input stage for an argument that already is the sequence of generators;
    /// the argument is passed through unchanged.
    let InBasis<'c> =
        {   new Input<'c,_> with
            member this.Process generators =
                Return generators
        }


module Output =
    /// Output stage that discards the basis and yields Unit.
    let NoOutput =
        {   new Output<_,_> with
            member this.Process _ =
                Return Unit
        }

    /// Output stage that returns the final sequence of basis polynomials unchanged.
    let OutBasis<'c> =
        {   new Output<'c,_> with
            member this.Process basis =
                Return basis
        }

    
module Container =
    /// Container kept in a reference cell holding a list. The polynomials
    /// themselves act as indexes: Get returns its index argument and Indexes
    /// returns the same sequence as All. Add conses the new item onto the list.
    let SimpleContainer<'p when 'p: equality> =
        {   new Container<'p,_,_> with
            member this.Init items = Let (Return (Ref.Ref(Seq.ToList items)))
            member this.Add cell item = codegen {
                yield Ref.Assign cell (List.Cons item (Ref.Deref cell))
                return item
            }
            member this.Get _ key = Return key
            member this.All cell = Return (Seq.CastTo (Ref.Deref cell))
            member this.Indexes cell = this.All cell
        }

    /// Container created with List.New; indexes are positions starting at Idx.zero.
    let ListContainer<'p> =
        {   new Container<'p,_,_> with
            member this.Init items = Let (Return (List.New items))
            member this.Add store item = codegen {
                yield List.Add store item
                // Index of the last element (Count - 1); assumes List.Add appends
                return Idx.sub (List.Count store) Idx.one
            }
            member this.Get store pos = Return (List.Item store pos)
            member this.All store = Return (Seq.CastTo store)
            member this.Indexes store = codegen {
                // Positions Idx.zero .. Count - 1
                return Seq.Make Idx.zero (Idx.sub (List.Count store) Idx.one)
            }
        }

module WorkingSet =
    /// Working set kept in a list: Add stores Pair i j, and Pick takes and
    /// removes the element at position Idx.zero (first-in first-out, assuming
    /// List.Add appends).
    let DirectPick<'i,'c> =
        {   new WorkingSet<'i,'c,_> with
            member this.Init _ = Let (Return (List.Empty()))
            member this.Add pending (_,a,b) = Return (List.Add pending (Pair a b))
            member this.Del pending (_,a,b) = Return (Control.Ignore (List.Remove pending (Pair a b)))
            member this.HasMore pending _ = Return ((List.Count pending) ^> Idx.zero)
            member this.Pick pending _ = codegen {
                use first = List.Item pending Idx.zero
                yield List.RemoveAt pending Idx.zero
                return first
            }
            member this.All pending _ = Return (Seq.CastTo pending)
        }

    /// Lexicographic comparer: order by `f1`, falling back to `f2` on ties.
    let MakeComparer f1 f2 =
        {   new IComparer<_> with
            member c.Compare(x, y) =
                let primary = f1 x y
                if primary <> 0 then primary else f2 x y
        }

    /// Working set kept in a SortedSet of entries Pair lcm (Pair i j), where lcm
    /// is the lcm of the two leading monomials. Entries are ordered by A.cmp on
    /// the lcm, with ties broken by Compare on the index pair. Pick removes and
    /// returns the index pair of the smallest entry.
    let LeastLCMPick(A:PolynomialAlgebra<_,_,_,_>,C:Container<_,_,_>) =
        {   new WorkingSet<_,_,_> with
            member this.Init _ = Let <| codegen {
                let byLcm x y = A.cmp (Fst x) (Fst y)
                let byPair x y = Compare (Snd x) (Snd y)
                let comparer = BinaryOp.Lift MakeComparer <@MakeComparer@>
                return SortedSet.New (comparer (BinaryOp.Flatten byLcm) (BinaryOp.Flatten byPair))
            }
            member this.Add entries (c,i,j) = codegen {
                let! poly_i = C.Get c i
                let! poly_j = C.Get c j
                let entry = Pair (A.TM.MM.lcm (A.LM poly_i) (A.LM poly_j)) (Pair i j)
                return Control.Ignore (SortedSet.Add entries entry)
            }
            member this.Del entries (c,i,j) = codegen {
                let! poly_i = C.Get c i
                let! poly_j = C.Get c j
                let entry = Pair (A.TM.MM.lcm (A.LM poly_i) (A.LM poly_j)) (Pair i j)
                return Control.Ignore (SortedSet.Remove entries entry)
            }
            member this.HasMore entries _ = Return ((SortedSet.Count entries) ^> Idx.zero)
            member this.Pick entries _ = codegen {
                use smallest = SortedSet.Min entries
                yield Control.Ignore (SortedSet.Remove entries smallest)
                return Snd smallest
            }
            member this.All entries _ = Return (Seq.Map (UnaryOp.Flatten Snd) entries)
        }


module ExpansionStrategy =
    /// Expansion strategy that never adds pairs to the working set.
    let NoExpansion =
        {   new ExpansionStrategy<_,_,_,_> with
            member e.Init(_,_) = Return Unit
            member e.Expand _ (_,_,_) = Return Unit
        }

    /// Seeds the working set with Seq.AllPairs of the initial indexes and, after
    /// index j is added, adds (i, j) for every other index i. Nothing is pruned.
    let DirectExpand(C:Container<_,_,_>,S:WorkingSet<_,_,_>) =
        {   new ExpansionStrategy<_,_,_,_> with
            member e.Init(c,s) = Prepend <| codegen {
                let! l = C.Indexes c
                for i in Seq.AllPairs l do
                    yield! S.Add s (c, Fst i, Snd i)
            }
            member e.Expand _ (c,s,j) = codegen {
                let! l = C.Indexes c
                for i in l do
                    yield! IfU (i ^<> j) (S.Add s (c,i,j))
            }
        }

    /// Defines (once, via DefineOnce) the staged predicate `in_triple dic ids k i`,
    /// which decides whether the new pair (i, k) is kept (Gebauer-Moeller criteria M/F).
    ///
    /// `dic` maps every index j to t_j = τ(LM_k, LM_j). This is assumed to be
    /// lcm(LM_k, LM_j) / LM_k, consistent with its use in SPoly.GenericSPoly. So
    /// divisibility among the t_j mirrors divisibility among the lcms of the
    /// candidate pairs (j, k).
    ///
    /// (i, k) is dropped when i = k, or when some j other than i and k satisfies:
    ///   - j < i and t_j divides t_i (equal lcms keep only the smallest index), or
    ///   - j > i and t_j properly divides t_i.
    /// j = k is skipped because t_k = τ(LM_k, LM_k) presumably is the unit monomial.
    /// (Previously the j > i case tested "t_i does not divide t_j", which discarded
    /// the pair with the smaller lcm together with the one it dominates.)
    let InTriple (A:PolynomialAlgebra<_,_,_,_>) =
        let in_triple dic ids k i = codegen {
            use t_i = Dictionary.Item dic i
            let! test = Iterate Seq.Forall ids <| fun j -> codegen {
                use t_j = Dictionary.Item dic j
                use lcm = A.TM.MM.lcm t_i t_j
                // Indexes that can never eliminate i
                let skip = (j ^= i) ^|| (j ^= k)
                // Earlier index: eliminates i if t_j divides t_i (including equality)
                let earlier = (j ^< i) ^&& (lcm ^<> t_i)
                // Later index: eliminates i only if t_j properly divides t_i
                let later = (j ^> i) ^&& ((lcm ^<> t_i) ^|| (t_j ^= t_i))
                return skip ^|| (earlier ^|| later)
            }
            return (i ^<> k) ^&& test
        }
        DefineOnce "in_triple" <| Generate(Fun4 in_triple)

    /// Buchberger's pair update with the Gebauer-Moeller criteria.
    /// Init defines the staged in_triple predicate (which becomes the strategy
    /// state) and seeds the working set with Seq.AllPairs of the initial indexes.
    /// Expand, after index k is added, does two things:
    ///   1. deletes every pending pair (i, j) that criterion B makes redundant,
    ///      i.e. LM_k divides lcm(LM_i, LM_j) and that lcm differs from both
    ///      lcm(LM_i, LM_k) and lcm(LM_j, LM_k);
    ///   2. adds (i, k) for every i accepted by in_triple.
    /// (The previous deletion filter only checked that the surviving set had
    /// elements different from i and from j. That removed every pending pair as
    /// soon as two new pairs survived, so the result was not a Groebner basis.)
    let ExpandBuchbergerTriples(A:PolynomialAlgebra<_,_,_,_>,C:Container<_,'idx,_>,S:WorkingSet<_,_,_>) =
        {   new ExpansionStrategy<_,_,_,_> with
            member e.Init(c,s) = codegen {
                let! in_triple = InTriple A
                return! PrependV in_triple <| codegen {
                    let! l = C.Indexes c
                    for i in Seq.AllPairs l do
                        yield! S.Add s (c, Fst i, Snd i)
                }
            }
            member e.Expand es (c,s,k) = codegen {
                use! ids = C.Indexes c
                use dic = Dictionary.New()
                let! c_k = C.Get c k
                use lm_k = A.LM c_k
                // t_i = τ(LM_k, LM_i) for every index i (k included)
                for i in ids do
                    let! c_i = C.Get c i
                    yield Dictionary.Add dic i (A.TM.MM.τ lm_k (A.LM c_i))
                // Indexes i whose new pair (i, k) survives criteria M and F
                use a = Seq.Filter (Function.Apply3 es dic ids k) ids
                let! pairs = S.All s c
                // Criterion B for a pending pair q = (i, j)
                let redundant q = codegen {
                    let! c_i = C.Get c (Fst q)
                    let! c_j = C.Get c (Snd q)
                    use lm_i = A.LM c_i
                    use lm_j = A.LM c_j
                    use lcm_ij = A.TM.MM.lcm lm_i lm_j
                    let k_divides = (A.TM.MM.lcm lm_k lcm_ij) ^= lcm_ij
                    let differs_ik = (A.TM.MM.lcm lm_i lm_k) ^<> lcm_ij
                    let differs_jk = (A.TM.MM.lcm lm_j lm_k) ^<> lcm_ij
                    return k_divides ^&& (differs_ik ^&& differs_jk)
                }
                // Materialize the list first so deletions do not disturb the iteration
                use! d = Iterate Seq.Filter (Seq.ToList pairs) redundant
                for i in d do
                    yield! S.Del s (c, Fst i, Snd i)
                for i in a do
                    yield! S.Add s (c,i,k)
            }
        }


module SPoly =
    /// S-polynomial built from the reductums only:
    ///   τ(lc_i, lc_j)·τ(lm_i, lm_j)·RT(c_i) − τ(lc_j, lc_i)·τ(lm_j, lm_i)·RT(c_j)
    /// Assuming τ(a, b) is the cofactor lcm(a, b)/a, the leading terms of the two
    /// scaled polynomials cancel, so they are never multiplied out.
    let GenericSPoly(A:PolynomialAlgebra<_,_,_,_>,C:Container<_,_,_>) =
        {   new SPoly<_,_,_,_> with
            member this.Init _ = Return Unit
            member this.σ _ (c,i,j) = codegen {
                let! first = C.Get c i
                let! second = C.Get c j
                let lmFirst, lmSecond = A.LM first, A.LM second
                let lcFirst, lcSecond = A.LC first, A.LC second
                let left = A.mul (A.fromTerm (A.CR.τ lcFirst lcSecond) (A.TM.MM.τ lmFirst lmSecond)) (A.RT first)
                let right = A.mul (A.fromTerm (A.CR.τ lcSecond lcFirst) (A.TM.MM.τ lmSecond lmFirst)) (A.RT second)
                return A.sub left right
            }
        }


module NormalRemainder =
    /// Identity remainder: polynomials are returned unreduced.
    let NoRemainder =
        {   new NormalRemainder<_,_,_> with
            member this.Init _ = Return Unit
            member this.NR _ (_,poly) = Return poly
        }

    /// Defines (once, recursively) the staged function `remainder divisors p`.
    /// It folds A.rem over the divisors and repeats while that pass still
    /// changes the polynomial.
    let Remainder (A:PolynomialAlgebra<_,_,_,_>) =
        let step self divisors poly =
            Generate <| codegen {
                use reduced = Seq.Fold (BinaryOp.Flatten A.rem) poly divisors
                // Recurse until a full pass leaves the polynomial unchanged
                return Control.If (reduced ^<> poly) (Function.Apply2 self divisors reduced) reduced
            }
        DefineOnceRec "remainder" (fun self -> BinaryOp.Flatten (step self))

    /// Normal remainder computed by repeated division by every polynomial in the container.
    let PolyDivNR(A:PolynomialAlgebra<_,_,_,_>,C:Container<_,_,_>) =
        {   new NormalRemainder<_,_,_> with
            member this.Init _ = Remainder A
            member this.NR remainder (c,poly) = codegen {
                let! divisors = C.All c
                return Function.Apply2 remainder divisors poly
            }
        }


module ReductionStrategy =
    /// Never flags a polynomial as redundant.
    let NoReduction<'i,'c> =
        {   new ReductionStrategy<'i,'c,_> with
            member this.Init _ = Return Unit
            member this.Reduce _ (_,_) = Return False
        }

    /// Flags polynomial i as redundant when the leading monomial of another
    /// still-retained polynomial j divides LM_i (tested as lcm(LM_i, LM_j) = LM_i).
    /// The state is a list of retained indexes built from C.Indexes. A redundant
    /// index is removed from it, so of two polynomials with the same leading
    /// monomial only the one tested first is dropped.
    let EliminateDivisors(A:PolynomialAlgebra<_,_,_,_>,C:Container<_,_,_>) =
        {   new ReductionStrategy<_,_,_> with
            member this.Init c = Let <| codegen {
                let! indexes = C.Indexes c
                return List.New indexes
            }
            member this.Reduce retained (c,i) = codegen {
                let! poly_i = C.Get c i
                use lead_i = A.LM poly_i
                // Does LM_j divide LM_i?
                let divides_lead_i j = codegen {
                    let! poly_j = C.Get c j
                    use lead_j = A.LM poly_j
                    return A.TM.MM.lcm lead_i lead_j ^= lead_i
                }
                let! redundant =
                    Iterate Seq.Exists retained (fun j -> If (j ^= i) (Return False) (divides_lead_i j))
                yield Control.If redundant (List.Remove retained i) False
            }
        }


module CanonicalForm =
    /// Identity canonical form.
    let NoOperation<'t,'c when 't: equality> =
        {   new CanonicalForm<'t,'c,_> with
            member this.Init _ = Return Unit
            member this.Canonicalize _ (_,poly) = Return poly
        }

    /// Scales a polynomial by the field inverse of its leading coefficient
    /// (presumably making it monic).
    let NormalizePoly(A:PolynomialAlgebra<_,_,_,'p>,C:Container<'p,_,'c>) =
        {   new CanonicalForm<_,'c,_> with
            member this.Init _ = Return Unit
            member this.Canonicalize _ (_,poly) = codegen {
                return A.scalar (Field.Inverse A.CR (A.LC poly)) poly
            }
        }

    /// Defines (once, recursively) the staged function `canonicalize divisors p`.
    /// The zero polynomial is returned as-is. Otherwise the leading part of p
    /// (p − RT p) is kept and the reductum RT p is replaced by the canonical form
    /// of remainder(divisors, RT p).
    let Canonicalize (A:PolynomialAlgebra<_,_,_,_>) =
        let step self remainder divisors poly = Generate <| codegen {
            let isZero = A.isZero poly
            return! If isZero (Return poly) <| codegen {
                use tail = A.RT poly
                use tail' = Function.Apply2 self divisors (Function.Apply2 remainder divisors tail)
                return A.add (A.sub poly tail) tail'
            }
        }
        codegen {
            let! remainder = NormalRemainder.Remainder A
            yield! DefineOnceRec "canonicalize" (fun self -> BinaryOp.Flatten (step self remainder))
        }

    /// Reduced-basis canonical form: canonicalizes p against every polynomial in
    /// the container, then scales the result by the inverse of its leading coefficient.
    let ReducedBasis(A:PolynomialAlgebra<_,_,_,_>,C:Container<_,_,_>) =
        {   new CanonicalForm<_,_,_> with
            member this.Init _ = Canonicalize A
            member this.Canonicalize canon (c,poly) = codegen {
                let! divisors = C.All c
                use reduced = Function.Apply2 canon divisors poly
                let scale = Field.Inverse A.TM.CR (A.LC reduced)
                return A.scalar scale reduced
            }
        }


/// Default solver instance: no tracing, reduced basis, dense monomials with lex
/// order, least-lcm pair selection and the Buchberger-triple pair update.
let gb =
    // Coefficients, monomials, terms and ordering
    let trace = Debug.NoTrace
    let kind = ReducedBasis
    let field = Field.QF
    let monoid = MonomialMonoid.Dense
    let terms = TermModule.Generic(field, monoid)
    let order = TermOrder.Dense.Lex monoid
    let algebra = PolynomialAlgebra.Generic(terms, order)
    // Pluggable strategies
    let input = Input.InBasis
    let container = Container.ListContainer
    let workingSet = WorkingSet.LeastLCMPick(algebra, container)
    let expansion = ExpansionStrategy.ExpandBuchbergerTriples(algebra, container, workingSet)
    let spoly = SPoly.GenericSPoly(algebra, container)
    let remainder = NormalRemainder.PolyDivNR(algebra, container)
    let reduction = ReductionStrategy.EliminateDivisors(algebra, container)
    let canonical = CanonicalForm.ReducedBasis(algebra, container)
    let output = Output.OutBasis
    let solver = GBSolver(trace, kind, algebra, input, container, workingSet, expansion, spoly, remainder, reduction, canonical, output)
    GetV solver

(*
#r "bin/Debug/Groebner.dll";;
#r "FSharp.PowerPack.dll";;
#r "FSharp.PowerPack.Linq.dll";;
fsi.AddPrinter Quote.Print;;
open Value;;
open Codegen;;
open Algebra;;
open Polynomial;;
open GBasis;;

let t = Debug.ConsoleTrace 5
let bk = ReducedBasis
let K = Field.QQ
let mg = MonomialMonoid.Dense
let tm = TermModule.Generic(K,mg)
let ord = TermOrder.Dense.Lex mg
let pa = PolynomialAlgebra.Generic(tm,ord)
let ip = Input.InBasis
let pc = Container.ListContainer
let ws = WorkingSet.DirectPick
let es = ExpansionStrategy.DirectExpand(pc,ws)
let sp = SPoly.GenericSPoly(pa,pc)
let nr = NormalRemainder.PolyDivNR(pa,pc)
let rs = ReductionStrategy.EliminateDivisors(pa,pc)
let cf = CanonicalForm.ReducedBasis(pa,pc)
let op = Output.OutBasis
let qgb = GBSolver(t,bk,pa,ip,pc,ws,es,sp,nr,rs,cf,op)
let gb = GetV qgb;;

let makepoly l = UnaryOp.AppV pa.fromTerms (l |> List.map (fun (c,l) -> BinaryOp.AppV tm.make c <| UnaryOp.AppV mg.fromLog (seq l)) |> seq)
let f1 = makepoly[1N,[2;0;0]; 1N,[0;1;0]; 1N,[0;0;1]; -1N,[0;0;0]]
let f2 = makepoly[1N,[1;0;0]; 1N,[0;2;0]; 1N,[0;0;1]; -1N,[0;0;0]]
let f3 = makepoly[1N,[1;0;0]; 1N,[0;1;0]; 1N,[0;0;1]; -1N,[0;0;0]];;

gb[f1;f2;f3];;
*)
