from fastapi import FastAPI
from pydantic import BaseModel
import fitz  # PyMuPDF
import re
import pytesseract
import os  # only needed for the TESSERACT_CMD override below

# Path to the Tesseract executable used by pytesseract. Set the TESSERACT_CMD
# environment variable to override it (e.g. on non-Windows hosts); otherwise a
# Windows install path is used. Nothing visible in this module calls
# pytesseract yet - OCR is still marked as "later" in /extract.
pytesseract.pytesseract.tesseract_cmd = os.environ.get(
    "TESSERACT_CMD", r"C:\Program Files\Tesseract-OCR\tesseract.exe"
)

from datetime import datetime
from difflib import SequenceMatcher

# FastAPI application instance; the POST /extract endpoint below is registered on it.
app = FastAPI()

class Req(BaseModel):
    """Request body for POST /extract.

    NOTE(review): id_supplier_invoice is accepted but not read by /extract in
    this module - presumably it lets the caller correlate the response.
    """

    id_supplier_invoice: int
    # Path of the PDF on a filesystem this service can read.
    pdf_path: str
    # One dict per candidate supplier, sent by the PHP caller. Keys read by the
    # matcher: id_supplier, supplier_name, supplier_name_norm,
    # supplier_postcode_norm, supplier_vat_no, supplier_aliases (a list).
    # Typed as list[dict] so a non-object entry is rejected at request
    # validation instead of failing later inside the matcher.
    suppliers: list[dict]

# ---------- helpers ----------
def norm_name(s: str) -> str:
    """Normalise a company name for comparison.

    Upper-cases the name, turns the punctuation . , ( ) [ ] { } into spaces,
    drops the whole-word legal-form tokens LIMITED/LTD/PLC/LLP/INC/CO/COMPANY,
    and collapses runs of whitespace. Falsy input yields "".

    >>> norm_name("Acme Widgets Ltd.")
    'ACME WIDGETS'
    """
    if not s:
        return ""
    punct_free = re.sub(r"[\.\,\(\)\[\]\{\}]", " ", s.upper())
    without_suffixes = re.sub(
        r"\b(LIMITED|LTD|PLC|LLP|INC|CO|COMPANY)\b", " ", punct_free
    )
    # split()/join collapses any whitespace run and trims both ends.
    return " ".join(without_suffixes.split())

def similarity(a: str, b: str) -> float:
    """Return how similar *a* and *b* are as a percentage (0-100), using difflib's SequenceMatcher ratio."""
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    return 100.0 * matcher.ratio()

def extract_text_pdf(path: str) -> str:
    """Return the plain text of every page of the document at *path*, newline-joined and stripped."""
    with fitz.open(path) as pdf:
        page_texts = [pg.get_text("text") for pg in pdf]
    return "\n".join(page_texts).strip()

def is_text_pdf(text: str, min_chars: int = 300) -> bool:
    """Heuristically decide whether a PDF has a usable text layer.

    Args:
        text: text extracted from the PDF.
        min_chars: minimum length (default 300, the MVP threshold) for the
            text layer to be trusted; shorter text is treated as needing OCR.

    Returns:
        True when ``len(text) >= min_chars``.
    """
    return len(text) >= min_chars

def find_vat_no(text: str) -> str | None:
    # UK VAT common patterns: "VAT No: GB123456789" etc.
    m = re.search(r"\bVAT\s*(?:No|Number)?\s*[:#]?\s*(GB)?\s*([0-9]{9})\b", text, re.IGNORECASE)
    if m:
        return ("GB" + m.group(2)).upper()
    return None

def find_postcode(text: str) -> str | None:
    # Better UK postcode (still not perfect, but much stronger than the earlier regex)
    t = text.upper()
    m = re.search(r"\b([A-Z]{1,2}\d{1,2}[A-Z]?)\s*(\d[A-Z]{2})\b", t)
    if m:
        return (m.group(1) + m.group(2)).replace(" ", "")
    return None

def find_invoice_number(text: str) -> str | None:
    patterns = [
        # Invoice No / Invoice Number
        r"\bInvoice\s*(?:No|Number|#)\s*[:\-]?\s*([A-Z0-9][A-Z0-9\-\/]{2,24})\b",
        r"\bInv\s*(?:No|#)\s*[:\-]?\s*([A-Z0-9][A-Z0-9\-\/]{2,24})\b",
    ]

    for p in patterns:
        m = re.search(p, text, re.IGNORECASE)
        if m:
            candidate = m.group(1).strip()

            # Must contain at least one digit
            if not re.search(r"\d", candidate):
                continue

            # Explicitly block common false positives
            if candidate.upper() in {"DATE", "INVOICE", "NUMBER"}:
                continue

            return candidate

    return None


def parse_date_any(text: str) -> str | None:
    # dd/mm/yyyy or dd-mm-yy
    m = re.search(r"\b(\d{1,2})[\/\-](\d{1,2})[\/\-](\d{2,4})\b", text)
    if m:
        d = int(m.group(1))
        mo = int(m.group(2))
        y = int(m.group(3))
        if y < 100:
            y += 2000
        try:
            dt = datetime(y, mo, d)
            if 2015 <= dt.year <= 2035:
                return dt.strftime("%Y-%m-%d")
        except:
            pass

    # "01 Feb 2026"
    m = re.search(
        r"\b(\d{1,2})\s+(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Sept|Oct|Nov|Dec)\w*\s+(\d{4})\b",
        text,
        re.IGNORECASE
    )
    if m:
        d = int(m.group(1))
        mon = m.group(2)[:3].title()
        y = int(m.group(3))
        try:
            dt = datetime.strptime(f"{d} {mon} {y}", "%d %b %Y")
            if 2015 <= dt.year <= 2035:
                return dt.strftime("%Y-%m-%d")
        except:
            pass

    return None

def find_invoice_date(text: str) -> str | None:
    """Find the invoice date in *text*, returned as ISO "YYYY-MM-DD" or None.

    Labels are tried in priority order: "Invoice Date", then "Tax Point Date",
    then a bare "Date". Each label is searched over the first 100 lines before
    the next label is considered, so a specific label wins even if a generic
    one appears above it. For a labelled line, the text after the label is
    parsed first. If it holds no date (e.g. the label is a table header), the
    following line is tried. The bare "Date" label ignores occurrences that
    name another kind of date (due, payment, delivery, order, despatch). If no
    label yields a date, the first date in the first 140 lines is returned.
    """
    label_patterns = [
        re.compile(r"Invoice\s*Date\s*[:\-]?", re.IGNORECASE),
        re.compile(r"Tax\s*Point\s*Date\s*[:\-]?", re.IGNORECASE),
        # Generic label last. The optional "other" prefix lets us recognise
        # and skip "Due Date" etc. without consuming the real "Date" elsewhere.
        re.compile(
            r"\b(?:(?P<other>Due|Payment|Delivery|Order|Despatch|Dispatch)\s*)?Date\s*[:\-]?",
            re.IGNORECASE,
        ),
    ]
    lines = text.splitlines()

    for label in label_patterns:
        for i, line in enumerate(lines[:100]):  # usually near the top
            labelled = False
            for m in label.finditer(line):
                if m.groupdict().get("other"):
                    continue  # not the invoice date
                labelled = True
                d = parse_date_any(line[m.end():])
                if d:
                    return d
            # Label with no value on the same line: the value is often on the next one.
            if labelled and i + 1 < len(lines):
                d = parse_date_any(lines[i + 1])
                if d:
                    return d

    top = "\n".join(lines[:140])
    return parse_date_any(top)

def money_to_decimal(s: str) -> float | None:
    s = s.replace(",", "").replace("£", "").replace("GBP", "").strip()
    try:
        return float(s)
    except:
        return None

def find_totals(text: str):
    net = vat = gross = None
    currency = "GBP" if ("£" in text or "GBP" in text.upper()) else None

    lines = [l.strip() for l in text.splitlines() if l.strip()]
    for line in lines[::-1][:150]:  # search from bottom-ish backwards
        if gross is None and re.search(r"\b(Total\s*Due|Amount\s*Due|Balance\s*Due|Grand\s*Total|Total)\b", line, re.IGNORECASE):
            m = re.findall(r"£?\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})", line)
            if m:
                gross = money_to_decimal(m[-1])

        if vat is None and re.search(r"\bVAT\b", line, re.IGNORECASE) and not re.search(r"VAT\s*(No|Number)", line, re.IGNORECASE):
            m = re.findall(r"£?\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})", line)
            if m:
                vat = money_to_decimal(m[-1])

        if net is None and re.search(r"\b(Net|Subtotal|Sub\s*Total)\b", line, re.IGNORECASE):
            m = re.findall(r"£?\s*\d{1,3}(?:,\d{3})*(?:\.\d{2})", line)
            if m:
                net = money_to_decimal(m[-1])

    if gross is not None and net is not None and vat is None:
        v = round(gross - net, 2)
        if v >= 0:
            vat = v

    return net, vat, gross, currency

def guess_supplier_name(doc: fitz.Document) -> str | None:
    """Guess the supplier's (issuer's) name from the header of the first page.

    Candidates are text blocks whose top edge is within TOP_LIMIT_Y of the top
    of page 1. Blocks that look like the customer/addressee side ("Bill To",
    "Deliver To", "Account", ...) are skipped. From each remaining block, the
    first line is taken as the candidate name, skipping document titles
    ("Invoice", "Credit Note", ...), street-address lines ("12 High Street")
    and contact details. Candidates score higher the nearer they are to the
    top-left corner. A block that also carries supplier-style details
    (VAT/company/registration number, phone, email, web) gets a large bonus.

    Returns:
        The best-scoring name truncated to 255 characters, or None when the
        document has no pages or no block qualifies.
    """
    if doc.page_count == 0:
        return None
    page = doc[0]
    # Each block is (x0, y0, x1, y1, text, block_no, block_type); a smaller y0
    # is nearer the top of the page, which the scoring below relies on.
    blocks = page.get_text("blocks")

    customer_block_markers = re.compile(
        r"\b(BILL\s*TO|BILLED\s*TO|INVOICE\s*TO|DELIVER\s*TO|DELIVERY\s*ADDRESS|SHIP\s*TO|SOLD\s*TO|CUSTOMER|"
        r"ACCOUNT|CUSTOMER\s*(NO|NUMBER)|YOUR\s*REF|OUR\s*REF|PROJECT|SITE)\b",
        re.IGNORECASE
    )

    supplier_clues = re.compile(
        r"\b(VAT\s*(NO|NUMBER)?|REG(ISTRATION)?\s*NO|COMPANY\s*NO|TEL|PHONE|EMAIL|WWW\.|HTTPS?://)\b",
        re.IGNORECASE
    )

    bad_line = re.compile(
        r"\b(INVOICE|TAX\s+INVOICE|VAT\s+INVOICE|CREDIT\s+NOTE|STATEMENT|REMITTANCE)\b",
        re.IGNORECASE
    )

    TOP_LIMIT_Y = 280  # tweak if your invoices have a taller header
    LEFT_PREF_X = 240  # blocks starting left of this get a left-side bonus

    candidates: list[tuple[float, str]] = []

    for b in blocks:
        x0, y0, txt, block_type = b[0], b[1], b[4], b[6]

        # Text blocks only. Image blocks (type 1) carry a metadata line instead
        # of real text, and a top-left logo would otherwise score very highly.
        if block_type != 0:
            continue
        if y0 > TOP_LIMIT_Y:
            continue
        if not txt or len(txt.strip()) < 3:
            continue
        if customer_block_markers.search(txt):
            continue

        lines = [ln.strip() for ln in txt.splitlines() if ln.strip()]
        company_line = None
        for line in lines[:10]:
            if bad_line.search(line):
                continue
            if re.match(r"^\d+\s+\w+", line):  # street address, e.g. "12 High Street"
                continue
            if re.search(r"@|www\.|https?://|\+?\d[\d\s\-]{6,}", line):  # email / web / phone
                continue
            if 3 <= len(line) <= 80:
                company_line = line
                break

        if company_line is None:
            continue

        # Prefer left, prefer top; supplier clues outweigh position.
        score = max(0, LEFT_PREF_X - x0) * 0.8 + max(0, TOP_LIMIT_Y - y0) * 1.2
        if supplier_clues.search(txt):
            score += 200

        candidates.append((score, company_line))

    if not candidates:
        return None

    # max() keeps the earliest block on ties, as the stable reverse sort did.
    return max(candidates, key=lambda c: c[0])[1][:255]

def best_supplier_match(suppliers, raw_name, vat_no, postcode):
    """Match extracted supplier details against the caller's supplier list.

    Strategies run from most to least reliable; the first hit wins:

    1. VAT number (score 99.0). Compared on letters/digits only with any "GB"
       prefix dropped, so "GB 123 4567 89" matches a stored "123456789".
    2. Postcode (92.0). Only decisive when exactly one supplier has it; if
       several share it, the postcode is inconclusive and the name decides.
    3. Normalised name equality (90.0).
    4. Normalised alias equality (88.0).
    5. Fuzzy name similarity, auto-linked only when >= 88.0.

    Args:
        suppliers: list of supplier dicts (keys used: id_supplier,
            supplier_vat_no, supplier_postcode_norm, supplier_name_norm,
            supplier_name, supplier_aliases).
        raw_name: supplier name guessed from the document, or None.
        vat_no: VAT number found in the document, or None.
        postcode: postcode found in the document, or None.

    Returns:
        (id_supplier, method, score), or (None, None, 0.0) when nothing matches.
    """
    raw_norm = norm_name(raw_name or "")

    def canon_vat(value) -> str:
        # Uppercase letters/digits only, without a leading "GB" country prefix.
        v = re.sub(r"[^0-9A-Z]", "", str(value).upper())
        return v[2:] if v.startswith("GB") else v

    def supplier_norm(s) -> str:
        # Prefer the caller's precomputed normalisation; fall back to ours.
        return s.get("supplier_name_norm") or norm_name(s.get("supplier_name") or "")

    # 1) VAT match (highest confidence)
    if vat_no:
        v = canon_vat(vat_no)
        if v:
            for s in suppliers:
                if canon_vat(s.get("supplier_vat_no") or "") == v:
                    return s["id_supplier"], "vat", 99.0

    # 2) Postcode match - decisive only when unique
    if postcode:
        pc = postcode.replace(" ", "").upper()
        if pc:
            hits = [
                s for s in suppliers
                if (s.get("supplier_postcode_norm") or "").replace(" ", "").upper() == pc
            ]
            if len(hits) == 1:
                return hits[0]["id_supplier"], "postcode", 92.0
            # Several suppliers share this postcode (business parks, shared
            # offices): it cannot tell them apart, so let the name decide.

    if not raw_norm:
        return None, None, 0.0

    # 3) Normalised name exact match
    for s in suppliers:
        if supplier_norm(s) == raw_norm:
            return s["id_supplier"], "name_norm", 90.0

    # 4) Alias match
    for s in suppliers:
        aliases = s.get("supplier_aliases") or []
        # Only list/tuple aliases are understood; anything else (e.g. a raw
        # string whose format is not known here) is ignored.
        if not isinstance(aliases, (list, tuple)):
            continue
        for a in aliases:
            if isinstance(a, str) and norm_name(a) == raw_norm:
                return s["id_supplier"], "alias", 88.0

    # 5) Fuzzy match (only accept if strong)
    best = (None, None, 0.0)
    for s in suppliers:
        cand = supplier_norm(s)
        if not cand:
            continue
        score = similarity(raw_norm, cand)
        if score > best[2]:
            best = (s["id_supplier"], "fuzzy", score)

    if best[0] is not None and best[2] >= 88.0:
        return best

    return None, None, 0.0

# ---------- endpoint ----------
@app.post("/extract")
def extract(req: Req):
    """Extract invoice fields from the PDF at req.pdf_path and match its supplier.

    Reads the document text and guesses the supplier name from page 1's
    header. Extracts the VAT number, postcode, invoice number, invoice date and
    totals, then matches the supplier against req.suppliers. "confidence" is a
    weighted sum of which fields were found (capped at 0.999), and
    "needs_review" is 1 when it is below 0.75.

    Raises:
        HTTPException: 404 when req.pdf_path is not an existing file.
    """
    import os

    from fastapi import HTTPException

    # NOTE(review): pdf_path is opened exactly as supplied. If anything other
    # than the trusted PHP backend can reach this endpoint, confine it to an
    # allowed directory - the extracted text is echoed back in raw_text.
    if not os.path.isfile(req.pdf_path):
        raise HTTPException(status_code=404, detail="PDF not found")

    with fitz.open(req.pdf_path) as doc:
        text = "\n".join(page.get_text("text") for page in doc).strip()
        # OCR is not implemented yet; "ocr" only records that the text layer
        # looked too thin to trust.
        method = "text_pdf" if is_text_pdf(text) else "ocr"
        supplier_raw = guess_supplier_name(doc)

    vat_no = find_vat_no(text)
    postcode = find_postcode(text)
    inv_no = find_invoice_number(text)
    inv_date = find_invoice_date(text)
    net, vat, gross, currency = find_totals(text)

    id_supplier, match_method, match_score = best_supplier_match(
        req.suppliers, supplier_raw, vat_no, postcode
    )

    # confidence score (simple MVP)
    conf = 0.0
    conf += 0.25 if supplier_raw else 0.0
    conf += 0.35 if inv_no else 0.0
    conf += 0.30 if inv_date else 0.0
    conf += 0.10 if gross is not None else 0.0
    if id_supplier is not None:  # an id of 0 is still a match
        conf += 0.10

    conf = min(conf, 0.999)
    needs_review = conf < 0.75

    return {
        "extraction_method": method,
        "raw_text": text,
        "supplier_name_raw": supplier_raw,
        "vat_no_found": vat_no,
        "postcode_found": postcode,
        "invoice_number": inv_no,
        "invoice_date": inv_date,
        "net_total": net,
        "vat_total": vat,
        "gross_total": gross,
        "currency": currency or "GBP",
        "id_supplier": id_supplier,
        "supplier_match_method": match_method,
        "supplier_match_score": match_score,
        "confidence": round(conf, 3),
        "needs_review": 1 if needs_review else 0,
        "items": []  # line items later
    }
