# app.py  (UPDATED: adds Village selection + sampling grouped by Village too)
from flask import Flask, render_template, request, send_file, jsonify
import pandas as pd
import io
from openpyxl import Workbook, load_workbook
import json
from pathlib import Path
import threading

app = Flask(__name__)

# In-memory application state shared by the request handlers in this module.
# /upload replaces DF and AUTO_SAMPLE; these three values live only in process memory.
DF: pd.DataFrame = pd.DataFrame()  # currently loaded dataset (empty until an upload)
AUTO_SAMPLE: pd.DataFrame = pd.DataFrame()  # auto sample generated from DF at upload time
AUTO_APPENDED: bool = False  # append auto UUIDs only when download is clicked

# =========================
# UUID "no-repeat" settings
# =========================
UUID_COL = "uuid"  # must match dataset column name
# CSV ledger of UUIDs already handed out in downloads; stored next to this source file.
UUID_FILE = Path(__file__).with_name("uuid.csv")

# Held while USED_UUIDS/_UUID_MTIME are reloaded and while UUID_FILE is appended to.
_UUID_LOCK = threading.Lock()
USED_UUIDS: set[str] = set()  # in-memory copy of the normalized UUIDs listed in UUID_FILE
_UUID_MTIME: float | None = None  # UUID_FILE mtime at last load; None = never loaded or file absent


def _normalize_uuid(x) -> str | None:
    if pd.isna(x):
        return None
    s = str(x).strip()
    return s if s else None


def load_used_uuids() -> set[str]:
    """Read UUID_FILE and return the set of normalized, non-blank UUIDs it lists.

    The UUID column is the first column whose name, stripped and lower-cased,
    equals ``"uuid"``. A missing file, a file pandas cannot parse, or a file
    without such a column all yield an empty set.
    """
    if not UUID_FILE.exists():
        return set()
    try:
        ledger = pd.read_csv(UUID_FILE)
    except Exception:
        return set()

    uuid_col = next(
        (name for name in ledger.columns if str(name).strip().lower() == "uuid"),
        None,
    )
    if uuid_col is None:
        return set()

    normalized = (_normalize_uuid(value) for value in ledger[uuid_col].tolist())
    return {value for value in normalized if value is not None}


def refresh_used_uuids_if_changed() -> None:
    """Re-sync USED_UUIDS with UUID_FILE when the file appeared, vanished or changed.

    Change detection compares the file's modification time with the one
    recorded at the last load. A missing file resets the cache to empty.
    """
    global USED_UUIDS, _UUID_MTIME
    with _UUID_LOCK:
        if UUID_FILE.exists():
            current_mtime = UUID_FILE.stat().st_mtime
            if _UUID_MTIME is not None and current_mtime == _UUID_MTIME:
                return  # cache is up to date
            USED_UUIDS = load_used_uuids()
            _UUID_MTIME = current_mtime
        else:
            USED_UUIDS = set()
            _UUID_MTIME = None


def append_used_uuids(new_uuids) -> None:
    """Record UUIDs as used, both in memory (USED_UUIDS) and in UUID_FILE.

    Values are normalized with ``_normalize_uuid``. Blank/missing values are
    dropped, duplicates within ``new_uuids`` are collapsed, and UUIDs already
    recorded are skipped. If UUID_FILE changed on disk since it was last
    loaded, it is re-read first (supports manual edits during testing).

    Args:
        new_uuids: iterable of raw UUID values, or ``None``.
    """
    global USED_UUIDS, _UUID_MTIME

    # Normalize, drop blanks, and de-duplicate (keeping first-seen order) so a
    # UUID repeated in the input is not written to the ledger twice.
    cleaned = list(dict.fromkeys(
        nx for nx in map(_normalize_uuid, new_uuids or []) if nx
    ))
    if not cleaned:
        return

    with _UUID_LOCK:
        # reload if file changed (supports manual edits during testing)
        if UUID_FILE.exists():
            mtime = UUID_FILE.stat().st_mtime
            if _UUID_MTIME is None or mtime != _UUID_MTIME:
                USED_UUIDS = load_used_uuids()
                _UUID_MTIME = mtime
        else:
            USED_UUIDS = set()
            _UUID_MTIME = None

        to_add = [u for u in cleaned if u not in USED_UUIDS]
        if not to_add:
            return

        UUID_FILE.parent.mkdir(parents=True, exist_ok=True)
        df_add = pd.DataFrame({"uuid": to_add})

        # A zero-byte file has no header. Appending without one would turn the
        # first UUID into the header row, after which load_used_uuids finds no
        # "uuid" column and forgets every recorded UUID. So treat it as new.
        if UUID_FILE.exists() and UUID_FILE.stat().st_size > 0:
            # A hand-edited file may lack a trailing newline; without this check
            # the first new UUID would be glued onto the last existing one.
            with UUID_FILE.open("rb") as fh:
                fh.seek(-1, io.SEEK_END)
                ends_with_newline = fh.read(1) in (b"\n", b"\r")
            with UUID_FILE.open("a", encoding="utf-8", newline="") as fh:
                if not ends_with_newline:
                    fh.write("\n")
                df_add.to_csv(fh, header=False, index=False)
        else:
            df_add.to_csv(UUID_FILE, mode="w", header=True, index=False)

        # Mark UUIDs as used in memory only once they are on disk.
        USED_UUIDS.update(to_add)
        _UUID_MTIME = UUID_FILE.stat().st_mtime


# Load the used-UUID ledger (if present) once at import time so USED_UUIDS is
# populated before the first request is handled.
refresh_used_uuids_if_changed()


# =========================
# Existing helper functions
# =========================
def read_uploaded_file(file_storage) -> pd.DataFrame:
    """Parse an uploaded CSV or Excel file into a DataFrame.

    The format is chosen from the (case-insensitive) filename extension.

    Args:
        file_storage: an uploaded file object exposing ``.filename`` and a
            readable stream (e.g. a Werkzeug ``FileStorage``).

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: if the extension is not .csv, .xlsx or .xls.
        Any pandas/engine error raised while parsing (including ImportError
        when the optional .xls engine is not installed).
    """
    filename = (file_storage.filename or "").strip().lower()
    if filename.endswith(".csv"):
        return pd.read_csv(file_storage)
    if filename.endswith(".xlsx"):
        return pd.read_excel(file_storage, engine="openpyxl")
    if filename.endswith(".xls"):
        # openpyxl cannot read legacy BIFF .xls files. With engine=None pandas
        # inspects the bytes and picks the right engine (xlrd for real .xls).
        return pd.read_excel(file_storage)
    raise ValueError("Only .csv, .xlsx, .xls files are supported.")


def create_or_append_workbook(excel_bytes: bytes | None, df_to_write: pd.DataFrame, sheet_name="Manual_Append") -> io.BytesIO:
    if excel_bytes:
        wb = load_workbook(io.BytesIO(excel_bytes))
    else:
        wb = Workbook()
        default = wb.active
        wb.remove(default)

    if sheet_name not in wb.sheetnames:
        ws = wb.create_sheet(sheet_name)
        start_row = 1
    else:
        ws = wb[sheet_name]
        start_row = ws.max_row + 1

    if df_to_write.empty:
        out = io.BytesIO()
        wb.save(out)
        out.seek(0)
        return out

    if ws.max_row == 0 or start_row == 1:
        for c, col in enumerate(df_to_write.columns, start=1):
            ws.cell(row=1, column=c, value=col)
        start_row = 2

    for r, row in enumerate(df_to_write.itertuples(index=False), start=start_row):
        for c, val in enumerate(row, start=1):
            ws.cell(row=r, column=c, value=val)

    out = io.BytesIO()
    wb.save(out)
    out.seek(0)
    return out


def sample_logic(
    data: pd.DataFrame,
    pick_random_5_uc: bool,
    default_sample_n: int,
    min_count: int,
    key_cols: list[str],
    key_df: pd.DataFrame | None = None,
    per_key_sample_map: dict[str, int] | None = None
) -> pd.DataFrame:
    """
    Generic sampler by key_cols.
    - optional: pick random 5 UC per (DISTRICT, TEHSIL) (works only if UC exists)
    - eligible groups: count >= min_count
    - sample size per group:
        - if per_key_sample_map has key -> use that
        - else use default_sample_n
      always clamp to group size
    - optional restrict to key_df exact keys
    - excludes UUIDs already in uuid.csv
    """
    if data.empty:
        return pd.DataFrame()

    work = data.copy()

    refresh_used_uuids_if_changed()

    # Exclude already-used UUIDs
    if UUID_COL in work.columns and USED_UUIDS:
        norm = work[UUID_COL].map(_normalize_uuid)
        work = work[~norm.isin(USED_UUIDS)]

    if work.empty:
        return pd.DataFrame()

    # Optional random 5 UC per (DISTRICT, TEHSIL)
    if pick_random_5_uc and "UC" in work.columns:
        kept_parts = []
        for (dist, teh), part in work.groupby(["DISTRICT", "TEHSIL"], dropna=True):
            ucs = part["UC"].dropna().unique()
            if len(ucs) == 0:
                continue
            chosen = pd.Series(ucs).sample(n=min(5, len(ucs)), random_state=None).tolist()
            kept_parts.append(part[part["UC"].isin(chosen)])
        work = pd.concat(kept_parts, ignore_index=False) if kept_parts else pd.DataFrame()

    if work.empty:
        return pd.DataFrame()

    # Optional exact key restriction (user-selected keys)
    if key_df is not None and not key_df.empty:
        work = work.merge(key_df, on=key_cols, how="inner")
        if work.empty:
            return pd.DataFrame()

    sizes = work.groupby(key_cols).size().rename("count").reset_index()
    eligible = sizes[sizes["count"] >= min_count]
    if eligible.empty:
        return pd.DataFrame()

    eligible_rows = work.merge(eligible[key_cols], on=key_cols, how="inner")

    per_key_sample_map = per_key_sample_map or {}

    sampled_parts = []
    for key_vals, g in eligible_rows.groupby(key_cols, sort=False):
        # build a stable string key using || separator
        key_str = "||".join([str(x) for x in key_vals]) if isinstance(key_vals, tuple) else str(key_vals)
        n = int(per_key_sample_map.get(key_str, default_sample_n))
        if n < 1:
            n = default_sample_n
        n = min(n, len(g))
        sampled_parts.append(g.sample(n=n, random_state=None))

    return pd.concat(sampled_parts, ignore_index=False) if sampled_parts else pd.DataFrame()


def generate_auto_sample(data: pd.DataFrame) -> pd.DataFrame:
    """Build the fixed-rule automatic sample from ``data``.

    For every District, then every Tehsil inside it (first-appearance order),
    choose up to 5 random UCs. Then, via ``sample_logic``, keep groups of
    (DISTRICT, TEHSIL, UC, Deh[, Village]) with at least 50 rows and draw 50
    rows from each. Village joins the grouping only when the column exists.
    Returns the concatenated sample with a fresh RangeIndex, or an empty frame.
    """
    if data.empty:
        return pd.DataFrame()

    grouping = ["DISTRICT", "TEHSIL", "UC", "Deh"]
    if "Village" in data.columns:
        grouping = grouping + ["Village"]

    collected = []
    # sort=False keeps first-appearance order; groupby drops NaN keys like dropna().unique().
    for _, district_df in data.groupby("DISTRICT", sort=False):
        for _, tehsil_df in district_df.groupby("TEHSIL", sort=False):
            uc_values = tehsil_df["UC"].dropna().unique()
            if not len(uc_values):
                continue

            picked = pd.Series(uc_values).sample(n=min(5, len(uc_values)), random_state=None).tolist()
            sampled = sample_logic(
                tehsil_df[tehsil_df["UC"].isin(picked)],
                pick_random_5_uc=False,
                default_sample_n=50,
                min_count=50,
                key_cols=grouping,
                key_df=None,
                per_key_sample_map=None,
            )
            if not sampled.empty:
                collected.append(sampled)

    if not collected:
        return pd.DataFrame()
    return pd.concat(collected, ignore_index=True)


@app.route("/")
def index():
    """Render the landing page with the district list and auto-sample size."""
    has_data = not DF.empty
    if has_data:
        districts = sorted(DF["DISTRICT"].dropna().unique())
    else:
        districts = []
    auto_count = 0 if AUTO_SAMPLE.empty else len(AUTO_SAMPLE)
    return render_template(
        "index.html",
        districts=districts,
        auto_count=auto_count,
        has_data=has_data,
    )


# ---------------- Upload dataset ----------------
@app.route("/upload", methods=["POST"])
def upload():
    """Load an uploaded dataset and regenerate the auto sample.

    Responds 400 when no file was sent, when parsing fails, or when a column
    required for sampling (DISTRICT, TEHSIL, UC, Deh) is missing. On success
    returns the dataset and auto-sample row counts.

    Global state (DF, AUTO_SAMPLE, AUTO_APPENDED) is replaced only after
    everything succeeded. A failed upload therefore never leaves a new DF
    paired with the previous dataset's AUTO_SAMPLE.
    """
    global DF, AUTO_SAMPLE, AUTO_APPENDED

    if "file" not in request.files:
        return jsonify({"error": "No file uploaded."}), 400
    f = request.files["file"]
    if not f or not f.filename:
        return jsonify({"error": "Empty upload."}), 400

    try:
        new_df = read_uploaded_file(f)
    except Exception as e:
        return jsonify({"error": str(e)}), 400

    # generate_auto_sample and the manual endpoints index these columns directly;
    # report a clear 400 instead of failing later with a KeyError/500.
    required = ["DISTRICT", "TEHSIL", "UC", "Deh"]
    missing = [c for c in required if c not in new_df.columns]
    if missing:
        return jsonify({"error": "Missing required column(s): " + ", ".join(missing)}), 400

    new_auto = generate_auto_sample(new_df)

    DF = new_df
    AUTO_SAMPLE = new_auto
    AUTO_APPENDED = False  # DO NOT append on upload

    return jsonify({"rows": int(len(DF)), "auto_rows": int(len(AUTO_SAMPLE))})


# ---------------- Tehsil list with District context ----------------
@app.route("/get_tehsils", methods=["POST"])
def get_tehsils():
    """Return the distinct (DISTRICT, TEHSIL) pairs for the posted ``districts`` list."""
    if DF.empty:
        return jsonify({"tehsils": []})

    selected = request.get_json(force=True).get("districts", [])
    if not selected:
        return jsonify({"tehsils": []})

    in_selection = DF["DISTRICT"].isin(selected)
    pairs = (
        DF.loc[in_selection, ["DISTRICT", "TEHSIL"]]
        .dropna()
        .drop_duplicates()
        .sort_values(["DISTRICT", "TEHSIL"])
    )
    return jsonify({"tehsils": pairs.to_dict(orient="records")})


# ---------------- UC list with District+Tehsil context ----------------
@app.route("/get_ucs", methods=["POST"])
def get_ucs():
    """Return distinct (DISTRICT, TEHSIL, UC) rows for posted "DISTRICT||TEHSIL" keys.

    Keys without a "||" separator are ignored.
    """
    if DF.empty:
        return jsonify({"ucs": []})

    keys = request.get_json(force=True).get("tehsil_keys", [])
    if not keys:
        return jsonify({"ucs": []})

    pairs = [key.split("||", 1) for key in keys if "||" in key]
    wanted = pd.DataFrame(pairs, columns=["DISTRICT", "TEHSIL"]).drop_duplicates()

    out_cols = ["DISTRICT", "TEHSIL", "UC"]
    ucs = (
        DF.merge(wanted, on=["DISTRICT", "TEHSIL"], how="inner")[out_cols]
        .dropna()
        .drop_duplicates()
        .sort_values(out_cols)
    )
    return jsonify({"ucs": ucs.to_dict(orient="records")})


# ---------------- Village list with District+Tehsil+UC context ----------------
@app.route("/get_villages", methods=["POST"])
def get_villages():
    """Return distinct villages for posted "DISTRICT||TEHSIL||UC" keys.

    Responds with an empty list when no dataset is loaded, the dataset has no
    ``Village`` column, or no UC keys were sent. Keys with fewer than two
    "||" separators are ignored.
    """
    if DF.empty or "Village" not in DF.columns:
        return jsonify({"villages": []})

    uc_keys = request.get_json(force=True).get("uc_keys", [])  # "DISTRICT||TEHSIL||UC"
    if not uc_keys:
        return jsonify({"villages": []})

    uc_cols = ["DISTRICT", "TEHSIL", "UC"]
    triples = [key.split("||", 2) for key in uc_keys if key.count("||") >= 2]
    wanted = pd.DataFrame(triples, columns=uc_cols).drop_duplicates()

    out_cols = uc_cols + ["Village"]
    villages = (
        DF.merge(wanted, on=uc_cols, how="inner")[out_cols]
        .dropna()
        .drop_duplicates()
        .sort_values(out_cols)
    )
    return jsonify({"villages": villages.to_dict(orient="records")})


# ---------------- Deh preview (eligible by min_count) ----------------
@app.route("/preview_dehs", methods=["POST"])
def preview_dehs():
    """Preview the sampling groups that meet ``min_count`` for the posted selection.

    JSON body:
        tehsil_keys:  list of "DISTRICT||TEHSIL" (required; empty -> no groups)
        uc_keys:      optional list of "DISTRICT||TEHSIL||UC"
        village_keys: optional list of "DISTRICT||TEHSIL||UC||Village"
        pick_random_5_uc: optional flag; keep a random 5 UCs per tehsil
        min_count:    minimum group size (default 50; blank/invalid/<1 -> 50)

    Returns ``{"dehs": [...]}`` with one record per (DISTRICT, TEHSIL, UC, Deh
    [, Village]) group, largest first. Counts exclude already-used UUIDs.
    Used UUIDs are only read here, never appended.
    """
    if DF.empty:
        return jsonify({"dehs": []})

    payload = request.get_json(force=True)
    tehsil_keys = payload.get("tehsil_keys", [])
    uc_keys = payload.get("uc_keys", [])
    village_keys = payload.get("village_keys", [])  # "DISTRICT||TEHSIL||UC||Village"
    pick_random_5 = bool(payload.get("pick_random_5_uc", False))

    # A blank/null/non-numeric min_count falls back to the default instead of a 500.
    try:
        min_count = int(payload.get("min_count", 50))
    except (TypeError, ValueError, OverflowError):
        min_count = 50
    if min_count < 1:
        min_count = 50

    if not tehsil_keys:
        return jsonify({"dehs": []})

    # Filter by tehsil keys
    t_split = [x.split("||", 1) for x in tehsil_keys if "||" in x]
    t_df = pd.DataFrame(t_split, columns=["DISTRICT", "TEHSIL"]).drop_duplicates()
    filtered = DF.merge(t_df, on=["DISTRICT", "TEHSIL"], how="inner")

    # Optional UC filter
    if uc_keys:
        u_split = [x.split("||", 2) for x in uc_keys if x.count("||") >= 2]
        u_df = pd.DataFrame(u_split, columns=["DISTRICT", "TEHSIL", "UC"]).drop_duplicates()
        filtered = filtered.merge(u_df, on=["DISTRICT", "TEHSIL", "UC"], how="inner")

    # Optional Village filter (only if column exists)
    if village_keys and "Village" in filtered.columns:
        v_split = [x.split("||", 3) for x in village_keys if x.count("||") >= 3]
        v_df = pd.DataFrame(v_split, columns=["DISTRICT", "TEHSIL", "UC", "Village"]).drop_duplicates()
        filtered = filtered.merge(v_df, on=["DISTRICT", "TEHSIL", "UC", "Village"], how="inner")

    if filtered.empty:
        return jsonify({"dehs": []})

    # Exclude used UUIDs BEFORE the random UC pick, in the same order as
    # sample_logic. Otherwise a UC whose rows are all used could be picked
    # here, and the preview would not match what sampling can produce.
    refresh_used_uuids_if_changed()
    if UUID_COL in filtered.columns and USED_UUIDS:
        norm = filtered[UUID_COL].map(_normalize_uuid)
        filtered = filtered[~norm.isin(USED_UUIDS)]

    if filtered.empty:
        return jsonify({"dehs": []})

    # optional random 5 UC per tehsil
    if pick_random_5:
        kept = []
        for _, part in filtered.groupby(["DISTRICT", "TEHSIL"], dropna=True):
            vals = part["UC"].dropna().unique()
            if len(vals) == 0:
                continue
            chosen = pd.Series(vals).sample(n=min(5, len(vals))).tolist()
            kept.append(part[part["UC"].isin(chosen)])
        filtered = pd.concat(kept, ignore_index=False) if kept else pd.DataFrame()

    if filtered.empty:
        return jsonify({"dehs": []})

    # Grouping includes Village if present
    group_keys = ["DISTRICT", "TEHSIL", "UC", "Deh"]
    if "Village" in filtered.columns:
        group_keys.append("Village")

    sizes = (
        filtered.groupby(group_keys)
        .size()
        .reset_index(name="count")
        .sort_values("count", ascending=False)
    )

    sizes = sizes[sizes["count"] >= min_count]
    return jsonify({"dehs": sizes.to_dict(orient="records")})


# ---------------- Downloads ----------------
@app.route("/download_auto")
def download_auto():
    """Send the current auto sample as auto_sample.xlsx.

    The first successful download of a given auto sample records its UUIDs as
    used (AUTO_APPENDED prevents recording them twice). The workbook is built
    BEFORE the UUIDs are recorded, so a failure while writing the file (e.g.
    openpyxl rejecting a cell value) does not burn UUIDs the user never received.
    """
    global AUTO_APPENDED

    # Snapshot once so the file served and the UUIDs recorded come from the
    # same frame, even if /upload replaces AUTO_SAMPLE meanwhile.
    sample = AUTO_SAMPLE

    output = io.BytesIO()
    (sample if not sample.empty else pd.DataFrame()).to_excel(output, index=False)
    output.seek(0)

    # Append UUIDs ONLY when user downloads auto sample
    if not AUTO_APPENDED and not sample.empty and UUID_COL in sample.columns:
        append_used_uuids(sample[UUID_COL].tolist())
        AUTO_APPENDED = True

    return send_file(output, download_name="auto_sample.xlsx", as_attachment=True)


@app.route("/download_manual_file", methods=["POST"])
def download_manual_file():
    """Build the manual sample for the posted selection and send it as .xlsx.

    Form fields:
        tehsil_keys / uc_keys / village_keys / deh_keys: "##"-separated lists
            of "||"-joined keys (at least one tehsil key is required).
        pick_random_5_uc: "true" to keep a random 5 UCs per tehsil.
        min_count, default_sample_n: positive ints (blank/invalid/<1 -> 50).
        per_group_samples: JSON object mapping group key -> sample size.
    Files:
        collector (optional): existing .xlsx that the sample is appended to.

    The workbook is built before the sampled UUIDs are recorded as used, so an
    unreadable collector file or a save error does not consume those rows.
    """
    if DF.empty:
        return jsonify({"error": "Upload dataset first."}), 400

    def _positive_int_field(field: str, default: int = 50) -> int:
        # Blank, non-numeric or < 1 inputs fall back to the default instead of a 500.
        try:
            value = int(request.form.get(field, "") or default)
        except (TypeError, ValueError):
            return default
        return value if value >= 1 else default

    tehsil_keys_raw = request.form.get("tehsil_keys", "")
    uc_keys_raw = request.form.get("uc_keys", "")
    village_keys_raw = request.form.get("village_keys", "")
    deh_keys_raw = request.form.get("deh_keys", "")

    pick_random_5 = request.form.get("pick_random_5_uc", "false").lower() == "true"

    min_count = _positive_int_field("min_count")
    default_sample_n = _positive_int_field("default_sample_n")

    tehsil_keys = [x for x in tehsil_keys_raw.split("##") if x]
    uc_keys = [x for x in uc_keys_raw.split("##") if x]
    village_keys = [x for x in village_keys_raw.split("##") if x]
    deh_keys = [x for x in deh_keys_raw.split("##") if x]

    if not tehsil_keys:
        return jsonify({"error": "Select at least 1 Tehsil."}), 400

    # Optional collector
    collector_bytes = None
    if "collector" in request.files and request.files["collector"].filename:
        collector_bytes = request.files["collector"].read()

    # Per-group sample map (malformed or non-object JSON -> no overrides)
    per_samples_raw = request.form.get("per_group_samples", "{}")
    try:
        per_group_sample_map = json.loads(per_samples_raw) if per_samples_raw else {}
    except (ValueError, RecursionError):
        per_group_sample_map = {}
    if not isinstance(per_group_sample_map, dict):
        per_group_sample_map = {}

    # Filter by tehsils
    t_split = [x.split("||", 1) for x in tehsil_keys if "||" in x]
    t_df = pd.DataFrame(t_split, columns=["DISTRICT", "TEHSIL"]).drop_duplicates()
    filtered = DF.merge(t_df, on=["DISTRICT", "TEHSIL"], how="inner")

    # Optional UC filter
    if uc_keys:
        u_split = [x.split("||", 2) for x in uc_keys if x.count("||") >= 2]
        u_df = pd.DataFrame(u_split, columns=["DISTRICT", "TEHSIL", "UC"]).drop_duplicates()
        filtered = filtered.merge(u_df, on=["DISTRICT", "TEHSIL", "UC"], how="inner")

    # Optional Village filter
    if village_keys and "Village" in filtered.columns:
        v_split = [x.split("||", 3) for x in village_keys if x.count("||") >= 3]
        v_df = pd.DataFrame(v_split, columns=["DISTRICT", "TEHSIL", "UC", "Village"]).drop_duplicates()
        filtered = filtered.merge(v_df, on=["DISTRICT", "TEHSIL", "UC", "Village"], how="inner")

    # Optional Deh(+Village) filter for exact chosen groups
    key_cols = ["DISTRICT", "TEHSIL", "UC", "Deh"]
    if "Village" in filtered.columns:
        key_cols.append("Village")

    key_df = None
    if deh_keys:
        # Deh keys are 4-part or 5-part depending on whether Village is a key column
        parts = len(key_cols)
        d_split = [x.split("||", parts - 1) for x in deh_keys if x.count("||") >= parts - 1]
        key_df = pd.DataFrame(d_split, columns=key_cols).drop_duplicates()

    manual_sample = sample_logic(
        filtered,
        pick_random_5_uc=pick_random_5,
        default_sample_n=default_sample_n,
        min_count=min_count,
        key_cols=key_cols,
        key_df=key_df,
        per_key_sample_map=per_group_sample_map
    )

    # Build the file first; only a successfully built download may consume UUIDs.
    try:
        output = create_or_append_workbook(collector_bytes, manual_sample, sheet_name="Manual_Append")
    except Exception as e:
        if not collector_bytes:
            raise
        # Most likely the uploaded collector is not a readable .xlsx workbook.
        return jsonify({"error": f"Could not update collector workbook: {e}"}), 400

    # Append UUIDs ONLY because this endpoint is the actual download
    if not manual_sample.empty and UUID_COL in manual_sample.columns:
        append_used_uuids(manual_sample[UUID_COL].tolist())

    filename = "collector_updated.xlsx" if collector_bytes else "manual_sample.xlsx"
    return send_file(output, download_name=filename, as_attachment=True)


if __name__ == "__main__":
    # Start Flask's built-in server with debug mode and the auto-reloader disabled.
    app.run(use_reloader=False, debug=False)
