﻿module Quote

open System
open System.Reflection
open System.Collections.Generic
open Microsoft.FSharp.Reflection
open Microsoft.FSharp.Quotations
open Microsoft.FSharp.Quotations.ExprShape
open Microsoft.FSharp.Quotations.Patterns
open Microsoft.FSharp.Quotations.DerivedPatterns
open Microsoft.FSharp.Text.StructuredFormat
open Microsoft.FSharp.Text.StructuredFormat.Display
open Microsoft.FSharp.Text.StructuredFormat.LayoutOps
open Microsoft.FSharp.Linq.QuotationEvaluation


// Quotation helpers

/// True when `args` are exactly the lambda-bound variables `vars`, in the same order,
/// i.e. the call is a pure eta-expansion such as the one produced by quoting a bare
/// function name (`<@ f @>`). Unit-typed parameters/arguments are ignored on both sides,
/// since a unit argument may not appear as a parameter of the compiled call.
let private isEtaExpansion (vars: Var list list) (args: Expr list) =
    let isUnit (t: Type) = t = typeof<unit>
    let ps = vars |> List.concat |> List.filter (fun v -> not (isUnit v.Type))
    let xs = args |> List.filter (fun a -> not (isUnit a.Type))
    List.length ps = List.length xs
    && List.forall2 (fun (p: Var) (a: Expr) -> match a with Var v -> v = p | _ -> false) ps xs

/// If the quotation is an eta-expanded reference to a static function carrying a
/// reflected definition (e.g. `<@ f @>` for a `[<ReflectedDefinition>]` function `f`),
/// returns that reflected definition cast to the same type; otherwise returns the input
/// unchanged. Calls whose arguments are not exactly the lambda's own variables
/// (e.g. `<@ fun x -> f 3 @>`) are NOT replaced, because the definition of `f` would not be
/// equivalent to them.
let GetReflectedDefinition: Expr<'a> -> Expr<'a> = function
  | Lambdas(vars, Call(None, MethodWithReflectedDefinition q, args)) when isEtaExpansion vars args ->
      q |> Expr.Cast
  | x -> x

/// Specializes the quoted function `f` to the argument `x`.
/// If `GetReflectedDefinition f` is a lambda `fun v -> body`, this builds
/// `let v = x in <@ body @>`, evaluates it with the `Eval` extension, and returns the
/// resulting quotation. Otherwise it returns the plain application `f x`.
/// NOTE(review): evaluating the `let` requires `x` to be evaluable (closed). How the
/// evaluator treats `v` inside the nested quotation is not visible here; confirm.
let EvalPartial (f: Expr<'a -> 'b>) (x: Expr<'a>) : Expr<'b> =
    match GetReflectedDefinition f with
    | Lambda (param, body) ->
        let staged : Expr<Expr<'b>> = Expr.Let (param, x, Expr.Quote body) |> Expr.Cast
        staged.Eval()
    | _ ->
        Expr.Application (f, x) |> Expr.Cast

/// Recognizes a method call, either directly (`Call`) or as a fully-applied lambda whose
/// body is a call, e.g. `(fun a b -> M(a, b)) x y`. In the applied-lambda case each lambda
/// must take one variable and each application one argument. The call is then returned
/// beta-reduced: lambda variables are replaced by the applied arguments in both the
/// argument list AND the call target, so no lambda variable is left free.
let (|CurriedCall|_|) = function
    | Call(ctx,mi,args) -> Some (ctx,mi,args)
    | Applications(Lambdas(vars,Call(ctx,mi,args)),pars) when
        vars.Length = pars.Length && vars.Length = args.Length &&
        //List.forall2 (fun v a -> match v,a with [v1],Var a1 when v1=a1 -> true | _ -> false) vars args &&
        List.forall (fun p -> List.length p = 1) vars &&
        List.forall (fun p -> List.length p = 1) pars ->
            let bindings = List.zip (List.map List.head vars) (List.map List.head pars)
            // The last matching binding wins, so an inner (shadowing) lambda variable takes precedence.
            let lookup (v: Var) =
                List.fold (fun found (bv, p) -> if v = bv then Some p else found) None bindings
            let subst (e: Expr) = e.Substitute lookup
            // Substitute in the target too: e.g. `(fun s -> s.M(y)) a` must yield target `a`, not a free `s`.
            Some (Option.map subst ctx, mi, List.map subst args)
    | _ -> None

/// Recognizes a function *value* embedded in a quotation (a non-null `Value` of F#
/// function type) applied to arguments, each application supplying one expression.
/// Returns the first public `Invoke` method on the value's runtime type whose parameter
/// count equals the number of applications, together with the flattened argument list.
/// The function object itself is not returned. Fails to match (rather than throwing)
/// when the embedded value is null or no suitable `Invoke` exists.
let (|ValueApplication|_|) = function
    | Applications(Value(fn,ty),pars) when
        fn <> null &&
        FSharpType.IsFunction ty &&
        List.forall (fun p -> List.length p = 1) pars ->
            fn.GetType().GetMethods()
            |> Array.tryFind (fun (m:MethodInfo) -> m.Name = "Invoke" && m.GetParameters().Length = pars.Length)
            |> Option.map (fun m -> m, List.concat pars)
    | _ -> None


// Quotation pretty-printer

/// Pretty-printer that renders quotations in an approximation of F# source syntax,
/// using StructuredFormat layouts.
module internal QuotationsPrinter =
    /// (F# symbol, compiled name) pairs used to print operator calls and union cases
    /// in their source form.
    let opNameTable =
       [("[]", "op_Nil"); ("::", "op_ColonColon"); ("+", "op_Addition"); (":=", "op_ColonEquals"); ("ref", "Ref");
        ("~%%", "op_SpliceUntyped"); ("~++", "op_Increment"); ("~--", "op_Decrement"); ("-", "op_Subtraction");
        ("*", "op_Multiply"); ("**", "op_Exponentiation"); ("/", "op_Division"); ("@", "op_Append");
        ("^", "op_Concatenate"); ("%", "op_Modulus"); ("&&&", "op_BitwiseAnd"); ("|||", "op_BitwiseOr");
        ("^^^", "op_ExclusiveOr"); ("<<<", "op_LeftShift"); ("~~~", "op_LogicalNot"); (">>>", "op_RightShift");
        ("~+", "op_UnaryPlus"); ("~-", "op_UnaryNegation"); ("~&", "op_AddressOf"); ("~&&", "op_IntegerAddressOf");
        ("&&", "op_BooleanAnd"); ("||", "op_BooleanOr"); ("<=", "op_LessThanOrEqual"); ("=","op_Equality");
        ("<>","op_Inequality"); (">=", "op_GreaterThanOrEqual"); ("<", "op_LessThan"); (">", "op_GreaterThan");
        ("|>", "op_PipeRight"); ("||>", "op_PipeRight2"); ("|||>", "op_PipeRight3"); ("<|", "op_PipeLeft");
        ("<||", "op_PipeLeft2"); ("<|||", "op_PipeLeft3"); ("!", "op_Dereference"); (">>", "op_ComposeRight");
        ("<<", "op_ComposeLeft"); ("+=", "op_AdditionAssignment"); ("-=", "op_SubtractionAssignment");
        ("*=", "op_MultiplyAssignment"); ("/=", "op_DivisionAssignment"); ("..", "op_Range"); ("?", "op_Dynamic");
        ("fst", "Fst"); ("snd", "Snd"); ("ignore", "Ignore"); ("not", "Not");
        ("::", "Cons")]

    /// Maps a compiled name (e.g. "op_Addition") to its F# symbol ("+"), or returns the
    /// name unchanged when it has no entry. It returns a string, not a bool; callers test
    /// for an operator with `m.Name <> IsOpName m.Name`.
    let IsOpName =
        let t = Dictionary<_,_>()
        for (s,n) in opNameTable do t.Add(n,s)
        fun (n:string) -> if t.ContainsKey n then t.[n] else n

    /// Substring replacements applied, in order, to namespaces and type names to shorten
    /// them. An empty replacement drops that namespace.
    /// NOTE(review): these are plain substring replacements, so e.g. "StringBuilder"
    /// is printed as "stringBuilder".
    let typeNameTable =
       [("System.Collections.Generic", ""); ("System.Collections", ""); ("System", "");
        ("Microsoft.FSharp.Core", ""); ("Microsoft.FSharp.Collections", ""); ("Microsoft.FSharp.Control", "");
        ("Int32", "int"); ("String", "string"); ("Double", "float"); ("Single", "float32"); ("Unit", "unit");
        ("IEnumerable", "seq")]

    /// Removes the CLR generic-arity suffix ("`N") from a type name; null passes through.
    let stripArity (s: string) =
        if s = null then s
        else
            let i = s.IndexOf '`'
            if i >= 0 then s.Remove i else s

    /// Escapes a string so that it reads back as an F# string literal.
    let escapeString (s: string) =
        s.Replace("\\", "\\\\").Replace("\"", "\\\"").Replace("\r", "\\r").Replace("\n", "\\n").Replace("\t", "\\t")

    /// Renders a type as `Namespace.Name<Arg1,Arg2>`, shortened via `typeNameTable`.
    /// A type with no namespace but a declaring type (e.g. a type inside a top-level F#
    /// module) is rendered as `Declaring+Name`.
    let rec PrintType(t:Type) =
        let mkp (s: string, a: string) =
            if s = null then ""
            else
                let s = typeNameTable |> List.fold (fun (s:string) (l,r) -> s.Replace(l,r)) s
                let s = s.Trim().Trim([|'.';'+'|])
                if s = "" then "" else s+a
        let prefix =
            if t.Namespace = null && t.DeclaringType <> null then mkp(stripArity t.DeclaringType.Name,"+")
            else mkp(t.Namespace,".")
        if t.IsGenericType then
            // Strip the arity from each name part separately. A backtick at index 0 or 1 is
            // also removed (single-letter generic names), and a generic declaring type no longer
            // swallows the nested type's own name.
            let a = t.GetGenericArguments() |> Array.map PrintType |> String.concat ","
            prefix + mkp(stripArity t.Name,"") + "<" + a + ">"
        else prefix + mkp(t.Name,"")

    /// Converts a quotation into a layout. Node kinds not handled explicitly fall back to
    /// `objL` of the node.
    let rec quoteL(x:Expr) =
        let parensL = quoteL >> bracketL
        let argsL = List.map quoteL >> commaListL >> bracketL
        let oargsL = function [] -> emptyL | l -> argsL l
        let wordsL l = spaceListL (List.map wordL l)
        // Renders "target.Member". The target is the instance expression or, for static
        // members, the declaring type. A member whose DeclaringType is null (e.g. a top-level
        // closure Type) is rendered by its bare name instead of throwing.
        let memberL(i,m:#MemberInfo) =
            match i with
            | None when m.DeclaringType = null -> wordL m.Name
            | _ ->
                let targetL =
                    match i with
                    | Some a -> quoteL a
                    | None -> wordL (PrintType m.DeclaringType)
                targetL -- sepL "." ++ wordL m.Name
        match x with
        | Var v -> wordL v.Name
        | Value (null, _) -> wordL "()"
        | Value (a, _) when (a :? string) -> wordL ("\"" + escapeString (a.ToString()) + "\"")
        | Value (a, t) when FSharpType.IsFunction t -> memberL(None,a.GetType())
        | Value (a, _) -> objL a
        | Quote e -> wordL "<@" ++ quoteL e ++ wordL "@>"
        | Let (n, v, e) -> wordsL ["let"; n.Name; "="] -- quoteL v ++ wordL " in " -- quoteL e
        | LetRecursive ([n,v], e) -> wordsL ["let"; "rec"; n.Name; "="] -- quoteL v ++ wordL " in " -- quoteL e
        | Lambda (v, e) -> wordsL ["fun"; v.Name; "->"] -- quoteL e
        | Application (a, b) -> parensL a -- parensL b
        | CurriedCall (None, m, [p1]) when m.Name <> IsOpName m.Name -> wordL (IsOpName m.Name) ^^ parensL p1
        | CurriedCall (None, m, [p1; p2]) when m.Name <> IsOpName m.Name -> parensL p1 ^^ wordL (IsOpName m.Name) ^^ parensL p2
        | CurriedCall (i, m, l) -> memberL (i, m) -- argsL l
        | PropertyGet (Some i, p, l) when p.Name = "Item" -> quoteL i ++ sepL "." -- (List.map quoteL l |> commaListL |> squareBracketL)
        | PropertyGet (i, p, l) -> memberL (i, p) -- oargsL l
        | PropertySet (i, p, l, v) -> memberL (i, p) -- oargsL l ++ wordL "<-" -- quoteL v
        | FieldGet (i, f) -> memberL (i, f)
        | FieldSet (i, f, v) -> memberL (i, f) ++ wordL "<-" -- quoteL v
        | NewTuple l -> argsL l
        | NewRecord (t, l) -> FSharpType.GetRecordFields t |> List.ofArray |> List.zip l |> List.map (fun (e,pi) -> wordL pi.Name ^^ wordL "=" ^^ quoteL e) |> semiListL |> braceL
//        | NewArray (_, l) -> "[|" + String.concat "; " (List.map quoteL l) + "|]"
        | NewUnionCase (u, [p1]) when u.Name <> IsOpName u.Name -> wordL (IsOpName u.Name) ^^ parensL p1
        | NewUnionCase (u, [p1; p2]) when u.Name <> IsOpName u.Name -> parensL p1 ^^ wordL (IsOpName u.Name) ^^ parensL p2
        | NewUnionCase (u, l) -> wordL (PrintType u.DeclaringType) -- sepL "." ++ wordL u.Name -- oargsL l
        | NewObject (c, l) -> wordsL ["new"; PrintType c.DeclaringType] -- argsL l
        | Coerce (e, t) -> quoteL e -- wordsL [":>"; PrintType t]
//        | TypeTest (e, t) -> "(" + quoteL e + " :? " + PrintType t + ")"
//        | UnionCaseTest (e, u) -> "(" + quoteL e + " = " + u.Name + ")"
//        | TupleGet (e, i) -> "TupleGet(" + quoteL e + ", " + i.ToString() + ")"
        | IfThenElse (a, b, c) -> wordL "if" -- quoteL a ++ wordL "then" -- quoteL b ++ wordL "else" -- quoteL c
        | Sequential (a, b) -> quoteL a ++ rightL ";" ++ quoteL b
//        | ForIntegerRangeLoop (v, a, b, c) -> "for " + v.Name + " = " + quoteL a + " to " + quoteL b + " do " + quoteL c + " done"
        | WhileLoop (a, b) -> wordL "while" -- quoteL a ++ wordL "do" -- quoteL b ++ wordL "done"
//        | TryWith (a, _, _, v, b) -> "try " + quoteL a + " with " + v.Name + " -> " + quoteL b
//        | TryFinally (a, b) -> "try " + quoteL a + " finally " + quoteL b
//        | AddressOf e -> quoteL e
//        | AddressSet (a, b) -> quoteL a + " <- " + quoteL b
//        | VarSet (v, e) -> "VarSet(" + v.Name + ", " + quoteL e + ")"
//        | DefaultValue t -> "(default " + t.Name + ")"
//        | NewDelegate (t, l, e) -> x.ToString()
        | _ -> objL x

    /// Renders a quotation as a string wrapped in `<@ ... @>`, using default format options.
    let PrintQuote x = layout_to_string FormatOptions.Default <| wordL "<@ " -- quoteL x -- wordL " @>"

/// Renders quotation `q` as an F#-like string wrapped in `<@ ... @>`.
let Print (q: Expr) =
    q |> QuotationsPrinter.PrintQuote
//fsi.AddPrinter Print


// Quotation transformations

/// Rewrites a quotation. At each node, `pre` is applied first. The children of its result
/// are then transformed recursively and the node is rebuilt with the same shape. Finally,
/// `post` is applied to the rebuilt node.
let rec Transform pre post e =
    let rebuild node =
        match node with
        | ShapeVar v -> Expr.Var v
        | ShapeLambda (v, body) -> Expr.Lambda (v, Transform pre post body)
        | ShapeCombination (shape, children) ->
            let rewritten = children |> List.map (Transform pre post)
            RebuildShapeCombination (shape, rewritten)
    e |> pre |> rebuild |> post

/// If the expression is recognized by `CurriedCall` (a direct call, or an applied lambda
/// around a call), rebuilds it as a plain `Expr.Call` on the reduced target and arguments.
/// Any other expression is returned unchanged.
let ReduceCall: Expr<'t> -> Expr<'t> = fun e ->
    match e with
    | CurriedCall (target, mi, args) ->
        let call =
            match target with
            | Some obj -> Expr.Call (obj, mi, args)
            | None -> Expr.Call (mi, args)
        Expr.Cast call
    | _ -> e


// Quote combinators

/// Evaluates a typed quotation with the `Eval` extension method, which is brought into
/// scope by the file's `QuotationEvaluation` open.
let Eval (x: Expr<'a>) : 'a =
    x.Eval()

/// Builds `let <n> = a in body`, where `body` is produced by `f` from a quotation of the
/// freshly created variable named `n`.
let Let n (a: Expr<'v>) (f: Expr<'v> -> Expr<'e>) : Expr<'e> =
    let bound = Var (n, typeof<'v>)
    let body = f (Expr.Cast (Expr.Var bound))
    Expr.Cast (Expr.Let (bound, a, body))
/// Builds `let rec <n> = definition in body`. Both `a` (producing the definition) and `f`
/// (producing the body) receive a quotation of the recursive variable named `n`.
let LetRec n (a: Expr<'v> -> Expr<'v>) (f: Expr<'v> -> Expr<'e>) : Expr<'e> =
    let recVar = Var (n, typeof<'v>)
    let self : Expr<'v> = Expr.Cast (Expr.Var recVar)
    // The definition is built before the body, so anything `a` does happens first.
    let definition = (a self).Raw
    let body = f self
    Expr.Cast (Expr.LetRecursive ([ (recVar, definition) ], body))
/// Builds `fun <n> -> body`, where `body` is produced by `f` from a quotation of the
/// parameter variable named `n`.
let Lambda n (f: Expr<'a> -> Expr<'b>) : Expr<'a->'b> =
    let param = Var (n, typeof<'a>)
    let body = f (Expr.Cast (Expr.Var param))
    Expr.Cast (Expr.Lambda (param, body))
