#!/usr/bin/env python3
"""Standalone offline lineage verifier — single file, stdlib only.

Usage:
    python verify_lineage.py path/to/lineage.json
    python verify_lineage.py path/to/lineage.json --expected-function-reference <sha256>
    cat lineage.json | python verify_lineage.py -
    python verify_lineage.py --help

Exit code 0 if the lineage verifies, 1 if it is refused, 2 on bad input
(file not found, JSON parse error, etc).

This file is intentionally self-contained: no FastAPI, no SQLAlchemy,
no external packages. You can email it (or hand it on a USB stick) to
an auditor or regulator alongside a saved lineage JSON, and they can
replay the doctrine end-to-end without any access to the running
system.

The lineage JSON is the response shape of GET /api/admissibility/
lineage/{root_sha256}: ``{"root_sha256": str, "nodes": [...],
"edges": [...]}``. This script recomputes every claimed hash from
its claimed payload and identifies the EXACT node where the chain
breaks.

Doctrine: "Admissibility before calculation."
"""
from __future__ import annotations

import argparse
import hashlib
import json
import sys
from collections import defaultdict
from dataclasses import dataclass


# -- canonical hashing primitives (mirrors backend/app/admissibility/canonical.py) --

def canonical_json(obj) -> str:
    """Return the canonical JSON text for *obj*.

    Object keys are emitted in sorted order, separators carry no
    whitespace, non-ASCII characters are written literally rather than
    escaped, and NaN / Infinity are rejected with ``ValueError`` instead of
    being emitted as non-standard literals.
    """
    # Same encoder configuration json.dumps() would build for these options.
    encoder = json.JSONEncoder(
        sort_keys=True,
        separators=(",", ":"),
        ensure_ascii=False,
        allow_nan=False,
    )
    return encoder.encode(obj)


def content_hash(obj) -> str:
    """Return the hex SHA-256 of *obj*'s canonical JSON, encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(canonical_json(obj).encode("utf-8"))
    return hasher.hexdigest()


# -- verifier (mirrors backend/app/admissibility/verifier.py) --

@dataclass
class VerifyResult:
    """Verdict returned by :func:`verify_lineage`."""

    # True only when every check passed.
    admissible: bool
    # sha256 the failure is attributed to (a node, or the claimed root);
    # None on success and for failures not tied to a single node.
    broken_node_sha256: str | None
    # "OK" on success, otherwise a reason code with optional detail.
    reason: str


def verify_lineage(
    lineage: dict,
    *,
    expected_function_reference_sha256: str | None = None,
) -> VerifyResult:
    """Recompute every claimed hash in *lineage* and report the first break.

    Checks, in order:

    1. ``root_sha256`` is present and names a node in ``nodes``.
    2. Every edge carries ``parent_sha256``, ``position`` and
       ``child_sha256``.
    3. For every node, ``content_hash(canonical_payload)`` equals the
       claimed ``sha256``.
    4. For a ``function_node``, the payload's ``input_leaf_sha256s`` equal
       (as sorted lists) the children attached to it by ``edges``; when
       *expected_function_reference_sha256* is given, the payload's
       ``function_reference_sha256`` must equal it.
    5. For an ``output_root``, the only edge child is the payload's
       ``function_node_sha256``.

    Structurally malformed input (wrong JSON types, missing fields,
    payloads that cannot be canonically serialised) is reported as an
    inadmissible result rather than raised, so callers always get a verdict.

    NOTE(review): the expected-reference check only runs for function_node
    entries that are present, so a lineage containing none passes it
    vacuously; edge endpoints are also not required to appear in ``nodes``.
    Confirm both match the backend verifier this function mirrors.

    Args:
        lineage: Parsed lineage document with ``root_sha256``, ``nodes``
            and ``edges``.
        expected_function_reference_sha256: Optional hash every
            function_node must reference.

    Returns:
        ``VerifyResult(True, None, "OK")`` when all checks pass, otherwise
        a result with ``admissible=False`` describing the first failure.
    """
    if not isinstance(lineage, dict):
        return VerifyResult(
            False, None, "MALFORMED_LINEAGE: top-level value is not a JSON object"
        )

    nodes = lineage.get("nodes", [])
    edges = lineage.get("edges", [])
    root_sha256 = lineage.get("root_sha256")

    if root_sha256 is None:
        return VerifyResult(False, None, "MISSING_ROOT_SHA256")
    if not isinstance(nodes, (list, tuple)) or not isinstance(edges, (list, tuple)):
        return VerifyResult(
            False, None, "MALFORMED_LINEAGE: 'nodes' and 'edges' must be arrays"
        )
    if not any(isinstance(n, dict) and n.get("sha256") == root_sha256 for n in nodes):
        return VerifyResult(False, root_sha256, "ROOT_NOT_IN_NODE_SET")

    # Group edge children under their parent, ordered by position.
    children_by_parent: dict[str, list[tuple[int, str]]] = defaultdict(list)
    edge_index = 0
    try:
        for edge_index, edge in enumerate(edges):
            children_by_parent[edge["parent_sha256"]].append(
                (edge["position"], edge["child_sha256"])
            )
    except (KeyError, TypeError):
        # KeyError: a required field is absent. TypeError: the edge is not
        # an object, or its parent_sha256 is not hashable.
        return VerifyResult(
            False,
            None,
            f"MALFORMED_EDGE: edge #{edge_index} is not an object with "
            "parent_sha256, position and child_sha256",
        )
    try:
        for kids in children_by_parent.values():
            kids.sort()
    except TypeError:
        return VerifyResult(
            False, None, "MALFORMED_EDGE: edge positions are not mutually comparable"
        )

    for node in nodes:
        if not isinstance(node, dict):
            return VerifyResult(False, None, "MALFORMED_NODE: node is not a JSON object")
        claimed = node.get("sha256")
        if not isinstance(claimed, str):
            return VerifyResult(
                False, None, "MALFORMED_NODE: node has no string 'sha256'"
            )
        missing = [key for key in ("canonical_payload", "kind") if key not in node]
        if missing:
            return VerifyResult(
                False,
                claimed,
                f"MALFORMED_NODE: missing field(s) {', '.join(missing)}",
            )
        payload = node["canonical_payload"]
        kind = node["kind"]

        try:
            recomputed = content_hash(payload)
        except ValueError:
            # Canonical form forbids NaN/Infinity, and a lone surrogate
            # cannot be UTF-8 encoded (UnicodeEncodeError is a ValueError);
            # json.loads accepts both, so this is reachable from a file.
            return VerifyResult(
                False,
                claimed,
                f"UNHASHABLE_PAYLOAD at kind={kind}: payload has no canonical JSON form",
            )
        if recomputed != claimed:
            return VerifyResult(
                False,
                claimed,
                f"HASH_MISMATCH at kind={kind}: payload hashes to {recomputed[:12]}.., not {claimed[:12]}..",
            )

        if kind in ("function_node", "output_root") and not isinstance(payload, dict):
            return VerifyResult(
                False,
                claimed,
                f"MALFORMED_NODE: {kind} payload is not a JSON object",
            )
        node_children = children_by_parent.get(claimed, [])

        if kind == "function_node":
            try:
                payload_leaves = sorted(payload.get("input_leaf_sha256s", []))
                edge_children = sorted(child for _, child in node_children)
            except TypeError:
                return VerifyResult(
                    False,
                    claimed,
                    "MALFORMED_NODE: function_node leaves or edge children are not sortable",
                )
            if payload_leaves != edge_children:
                return VerifyResult(
                    False,
                    claimed,
                    "EDGE_LEAF_MISMATCH: function_node payload leaves differ from edge children",
                )
            if expected_function_reference_sha256 is not None:
                got = payload.get("function_reference_sha256")
                if got != expected_function_reference_sha256:
                    return VerifyResult(
                        False,
                        claimed,
                        (
                            "FUNCTION_REFERENCE_DRIFT: expected "
                            # str() keeps the message printable when the
                            # payload holds a non-string reference.
                            f"{expected_function_reference_sha256[:12]}.., got {str(got or '')[:12]}.."
                        ),
                    )
        if kind == "output_root":
            payload_function_node = payload.get("function_node_sha256")
            edge_children = [child for _, child in node_children]
            if edge_children != [payload_function_node]:
                return VerifyResult(
                    False,
                    claimed,
                    "EDGE_FUNCTION_MISMATCH: output_root payload function_node differs from edge child",
                )

    return VerifyResult(True, None, "OK")


# -- CLI --

def _summarize_lineage(lineage: dict) -> str:
    counts: dict[str, int] = defaultdict(int)
    for n in lineage.get("nodes", []):
        counts[n.get("kind", "?")] += 1
    parts = [f"{v} {k}" for k, v in sorted(counts.items())]
    return ", ".join(parts) if parts else "(empty)"


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: load a lineage JSON, verify it, report.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        0 when the lineage is admissible, 1 when the verifier refused it,
        2 when the input could not be read, decoded, or parsed into a JSON
        object.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Verify a saved FRTB SA Calculator lineage payload offline. "
            'Doctrine: "Admissibility before calculation."'
        ),
        epilog=(
            "Exit 0 = admissible. Exit 1 = lineage refused (verifier "
            "found a mismatch). Exit 2 = bad input (file not found, "
            "JSON parse error, etc)."
        ),
    )
    parser.add_argument(
        "path",
        help='Path to the lineage JSON, or "-" to read from stdin.',
    )
    parser.add_argument(
        "--expected-function-reference",
        dest="expected_func_ref",
        default=None,
        help=(
            "Optional. If supplied, reject the lineage if the function_node's "
            "function_reference_sha256 does not match this value. Use this "
            "when the bank publishes its expected engine + regime + "
            "parameter-set hash and you want to refuse drift."
        ),
    )
    parser.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="Print only OK / FAIL plus the broken node sha256 on failure.",
    )
    args = parser.parse_args(argv)

    try:
        if args.path == "-":
            text = sys.stdin.read()
            # Drop a leading BOM so piped files behave like opened ones.
            if text.startswith("\ufeff"):
                text = text[1:]
        else:
            # utf-8-sig decodes plain UTF-8 and also skips a leading BOM,
            # which json.loads would otherwise reject.
            with open(args.path, "r", encoding="utf-8-sig") as f:
                text = f.read()
    except (OSError, UnicodeDecodeError) as exc:
        # UnicodeDecodeError is not an OSError; a non-UTF-8 file is bad
        # input (exit 2), not a traceback.
        print(f"ERROR reading {args.path}: {exc}", file=sys.stderr)
        return 2

    try:
        lineage = json.loads(text)
    except json.JSONDecodeError as exc:
        print(f"ERROR parsing JSON: {exc}", file=sys.stderr)
        return 2

    # Valid JSON that is not an object (e.g. a bare list) cannot be a
    # lineage document; everything below relies on dict access.
    if not isinstance(lineage, dict):
        print(
            f"ERROR: lineage JSON must be an object, got {type(lineage).__name__}",
            file=sys.stderr,
        )
        return 2

    result = verify_lineage(
        lineage,
        expected_function_reference_sha256=args.expected_func_ref,
    )
    exit_code = 0 if result.admissible else 1

    if args.quiet:
        if result.admissible:
            print("OK")
        else:
            print(f"FAIL {result.broken_node_sha256 or '-'}")
        return exit_code

    # The report is printed for failures too, so a malformed "edges" value
    # must not abort it.
    try:
        edge_count = str(len(lineage.get("edges", [])))
    except TypeError:
        edge_count = "(malformed)"

    bar = "=" * 78
    print(bar)
    print("FRTB SA Calculator — offline lineage verifier")
    print(bar)
    print(f"Source                : {args.path}")
    print(f"Root sha256           : {lineage.get('root_sha256', '(missing)')}")
    print(f"Node summary          : {_summarize_lineage(lineage)}")
    print(f"Edges                 : {edge_count}")
    # Same "is not None" test the verifier uses to decide whether the
    # expectation is enforced, so the report never hides an active check.
    if args.expected_func_ref is not None:
        print(f"Expected fn reference : {args.expected_func_ref}")
    print(bar)
    print(f"admissible            : {result.admissible}")
    print(f"reason                : {result.reason}")
    if result.broken_node_sha256:
        print(f"broken_node_sha256    : {result.broken_node_sha256}")
    print(bar)

    return exit_code


# Script entry point: main()'s integer return value becomes the process
# exit status via SystemExit.
if __name__ == "__main__":
    raise SystemExit(main())
