# 표준 라이브러리
import os
import re
import json
import logging
import functools
from typing import List, Dict, Tuple
from concurrent.futures import ThreadPoolExecutor

# 서드파티 패키지
import requests
import numpy as np
import pandas as pd
import ezodf
import pdfplumber
import jsonlines
import chromadb
from docx import Document
from sentence_transformers import SentenceTransformer, util
from odf.opendocument import load
from odf.text import P, H

# 플라스크 웹 서버 구성
from flask import Flask, render_template, request, jsonify

# ==========================
# 환경/상수
# ==========================
OLLAMA_URL = "http://localhost:11434/api/chat"
MODEL = "gemma3:27b"

# 탐색 폭/임계값
INITIAL_TOP_K = 20
FINAL_TOP_K = 10
VECTOR_THRESHOLD = 0.6
RERANK_THRESHOLD = 0.6
MIN_HITS = 2

# 폴백 시에도 항상 상위 k개만 사용
TOP_K_ALWAYS = 8

# 컨텍스트 길이 제한(문자 기준, 대략 수천 토큰)
MAX_CTX_CHARS = 8000

# ==========================
# Flask 앱/로깅
# ==========================
app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False
history: List[Dict[str, str]] = []

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("kb-chat")

# ==========================
# 임베딩/벡터DB
# ==========================
# 경량 한국어 임베딩
model = SentenceTransformer("BM-K/KoSimCSE-roberta")

# 로컬 ChromaDB
chroma_client = chromadb.PersistentClient(path="./chroma_storage")
collection = chroma_client.get_or_create_collection(
    name="knowledge",
    metadata={"hnsw:space": "cosine"}
)

# 전역 지식/캐시
knowledge_chunks: List[str] = []
query_embedding_cache: Dict[str, np.ndarray] = {}
chunk_embedding_cache: Dict[str, np.ndarray] = {}  # 로딩 시 모든 청크 임베딩 저장

# ==========================
# 파일별 텍스트 추출
# ==========================
def extract_text_from_docx(path: str) -> str:
    doc = Document(path)
    return "\n".join(p.text for p in doc.paragraphs if p.text.strip())

def extract_text_from_excel(path: str) -> str:
    df = pd.read_excel(path)
    return df.to_string(index=False)

def extract_text_from_ods(path: str) -> str:
    ezodf.config.set_table_expand_strategy('all')
    doc = ezodf.opendoc(path)
    lines = []
    for sheet in doc.sheets:
        for row in sheet.rows():
            cells = [str(cell.value or "").strip() for cell in row]
            if any(cells):
                lines.append(" ".join([c for c in cells if c]))
    return "\n".join(lines)

def extract_text_from_odt(path: str) -> str:
    text = []
    doc = load(path)
    for elem in doc.getElementsByType(H):
        t = "".join(getattr(child, 'data', '') for child in elem.childNodes if hasattr(child, 'data'))
        if t.strip():
            text.append(t.strip())
    for elem in doc.getElementsByType(P):
        t = "".join(getattr(child, 'data', '') for child in elem.childNodes if hasattr(child, 'data'))
        if t.strip():
            text.append(t.strip())
    return "\n".join(line for line in text if line.strip())

def extract_text_from_pdf(path: str) -> str:
    text = []
    with pdfplumber.open(path) as pdf:
        for page in pdf.pages:
            page_text = page.extract_text() or ""
            text.append(page_text)
    return "\n".join(text)

def extract_text_from_jsonl(path: str) -> str:
    lines = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                item = json.loads(line)
                # 자유 텍스트 필드를 최대한 수집
                parts = []
                for key in ("question", "answer", "text", "content", "body"):
                    v = item.get(key, "")
                    if isinstance(v, str) and v.strip():
                        parts.append(v.strip())
                if parts:
                    lines.append(" ".join(parts))
            except json.JSONDecodeError:
                continue
    return "\n".join(lines)

def extract_text(path: str) -> str:
    extractors = {
        ".docx": extract_text_from_docx,
        ".xlsx": extract_text_from_excel,
        ".ods": extract_text_from_ods,
        ".odt": extract_text_from_odt,
        ".pdf": extract_text_from_pdf,
        ".txt": lambda p: open(p, encoding="utf-8").read(),
        ".jsonl": extract_text_from_jsonl
    }
    for ext, extractor in extractors.items():
        if path.endswith(ext):
            return extractor(path)
    # 확장자 미지정 시 일반 텍스트로 시도
    try:
        return open(path, encoding="utf-8").read()
    except Exception:
        return ""

# ==========================
# 문장형 파싱/청크화 (도메인 규칙 없음, 자연문 유지)
# ==========================
def normalize_whitespace(s: str) -> str:
    return re.sub(r'\s+', ' ', s).strip()

def _merge_short_lines(text: str) -> str:
    """
    목록/표기 등으로 잘린 짧은 줄들을 자연문으로 이어붙임.
    - 콜론(:) 뒤 내용을 문장으로 인정
    - 8자 미만 짧은 라인이 연속되면 한 문장으로 합치기
    """
    lines = [normalize_whitespace(x) for x in re.split(r'[\r\n]+', text)]
    merged: List[str] = []
    buf: List[str] = []

    def flush_buf():
        if buf:
            merged.append(normalize_whitespace(" ".join(buf)))
            buf.clear()

    for ln in lines:
        if not ln:
            flush_buf()
            continue
        if len(ln) >= 15 or ":" in ln:
            flush_buf()
            merged.append(ln)
        else:
            buf.append(ln)
    flush_buf()
    return "\n".join([m for m in merged if m])

def split_into_sentences(text: str) -> List[str]:
    """
    한국어 문장 분리(일반): 개행/마침표/물음표/느낌표 기준 + 콜론 뒤 구문 포함.
    너무 짧은 토막은 제거.
    """
    text = _merge_short_lines(text)
    # 문장 경계
    rough = re.split(r'(?<=[\.!?])\s+|(?<=\.)|(?<=\!)|(?<=\?)|\n+', text)
    sentences = []
    for seg in rough:
        s = normalize_whitespace(seg)
        if len(s) >= 15:
            sentences.append(s)
    return sentences

def parse_structured_content(text: str) -> List[str]:
    """
    지식파일을 읽어 '자연스러운 문장' 단위로 청크화.
    """
    chunks: List[str] = []
    sentences = split_into_sentences(text)
    if sentences:
        chunks.extend(sentences)
    else:
        # 내용이 극히 짧을 때는 통문장으로 저장
        t = normalize_whitespace(text)
        if t:
            chunks.append(t)
    return chunks

@functools.lru_cache(maxsize=128)
def compute_embedding(text: str) -> np.ndarray:
    """
    KoSimCSE는 접두어 없이 원문 그대로 인코딩.
    """
    return model.encode([text])[0]

def process_file(fpath: str) -> List[str]:
    fname = os.path.basename(fpath)
    try:
        raw_text = extract_text(fpath)
        logger.info(f"📄 {fname} 텍스트 추출 완료 (길이: {len(raw_text)}자)")
        parsed_chunks = parse_structured_content(raw_text)
        logger.info(f"🧩 {fname}에서 청크 {len(parsed_chunks)}개 생성")
        return parsed_chunks
    except Exception as e:
        logger.error(f"❌ 파일 처리 오류: {fname} - {e}")
        return []

def load_knowledge_from_folder(folder: str = "./knowledge"):
    """
    - 모든 청크 생성
    - 각 청크 임베딩을 사전 계산하여 chunk_embedding_cache에 저장
    - ChromaDB에 일괄 적재
    """
    global knowledge_chunks, chunk_embedding_cache
    all_chunks: List[str] = []

    logger.info(f"\n📁 knowledge 폴더 내용 스캔 중...")
    file_paths = [os.path.join(folder, fname) for fname in os.listdir(folder)
                  if os.path.isfile(os.path.join(folder, fname))]

    if not file_paths:
        logger.warning("⚠️ knowledge 폴더에 파일이 없습니다.")
        knowledge_chunks = []
        chunk_embedding_cache = {}
        return

    with ThreadPoolExecutor(max_workers=min(8, os.cpu_count() or 4)) as executor:
        chunks_list = list(executor.map(process_file, file_paths))

    for chunks in chunks_list:
        all_chunks.extend(chunks)

    logger.info(f"✔ 총 {len(all_chunks)}개 청크 수집됨")

    # 기존 컬렉션 초기화
    all_items = collection.get()
    if all_items and "ids" in all_items and all_items["ids"]:
        collection.delete(ids=all_items["ids"])

    # 임베딩 사전 계산 & 캐시에 저장하면서 적재
    knowledge_chunks = all_chunks
    chunk_embedding_cache = {}

    batch_size = 100
    for i in range(0, len(all_chunks), batch_size):
        batch = all_chunks[i:i + batch_size]

        embeddings: List[List[float]] = []
        ids: List[str] = []

        for j, chunk in enumerate(batch):
            v = compute_embedding(chunk)           # np.ndarray
            chunk_embedding_cache[chunk] = v       # 쿼리 단계 재인코딩 방지
            embeddings.append(v.tolist())
            ids.append(f"chunk_{i + j}")

        collection.add(
            documents=batch,
            embeddings=embeddings,
            ids=ids
        )
        logger.info(f"✅ {len(batch)}개 청크 임베딩 저장 완료 ({i + 1}~{i + len(batch)}/{len(all_chunks)})")

    logger.info("✅ knowledge_chunks/임베딩 캐시 등록 완료")
    logger.info("✅ 웹 애플리케이션 초기화 완료. API 사용 준비 완료.")
    logger.info(f"✨ http://127.0.0.1:8000 에서 서비스가 시작 되었습니다.")

# ==========================
# 검색/재정렬 (코사인) + 제한/조기종료
# ==========================
def preprocess_question(query: str) -> str:
    return query.strip()

def retrieve_relevant_context_with_scores(
    query: str,
    initial_top_k: int = INITIAL_TOP_K,
    final_top_k: int = FINAL_TOP_K
) -> Tuple[str, List[float]]:
    """
    같은 임베딩 모델 기반 코사인 유사도 rerank.
    - 후보 문서 임베딩은 로딩 단계에서 이미 캐시됨(재인코딩 금지)
    - 조기종료: 고유사도 문서 다수이면 즉시 상위만 반환
    """
    if not knowledge_chunks:
        logger.warning("⚠️ knowledge_chunks 비어 있음")
        return "", []

    processed_query = preprocess_question(query)
    if processed_query in query_embedding_cache:
        q_emb = query_embedding_cache[processed_query]
    else:
        q_emb = model.encode([processed_query])[0]
        query_embedding_cache[processed_query] = q_emb

    initial = collection.query(query_embeddings=[q_emb.tolist()], n_results=initial_top_k)
    docs = initial.get("documents", [[]])[0] if initial else []
    if not docs:
        logger.warning("⚠️ 관련 문단을 찾지 못했습니다.")
        return "", []

    # 재인코딩 금지: 캐시에서 꺼내기
    doc_embs: List[np.ndarray] = [chunk_embedding_cache[d] for d in docs if d in chunk_embedding_cache]
    if not doc_embs:
        logger.warning("⚠️ 후보 임베딩 캐시가 비어 있습니다.")
        return "", []

    sims = util.cos_sim(q_emb, np.stack(doc_embs)).cpu().numpy().flatten().tolist()

    # 조기 종료: 0.80 이상이 3개 이상이면 바로 top-k 확정
    if sum(1 for s in sims if s >= 0.80) >= 3:
        pairs = sorted(zip(docs, sims), key=lambda x: x[1], reverse=True)[:final_top_k]
        final_docs = [doc for doc, _ in pairs]
        final_scores = [float(sc) for _, sc in pairs]
        for i, (doc, sc) in enumerate(pairs[:5], start=1):
            logger.info(f"🔍 Rerank(early) Top {i}: {sc:.4f} | {doc[:120]}...")
        return "\n".join(final_docs), final_scores

    pairs = sorted(zip(docs, sims), key=lambda x: x[1], reverse=True)[:final_top_k]
    final_docs = [doc for doc, _ in pairs]
    final_scores = [float(sc) for _, sc in pairs]

    for i, (doc, sc) in enumerate(pairs[:5], start=1):
        logger.info(f"🔍 Rerank Top {i}: {sc:.4f} | {doc[:120]}...")

    return "\n".join(final_docs), final_scores

def get_best_context(
    user_message: str,
    vector_threshold: float = VECTOR_THRESHOLD,
    rerank_threshold: float = RERANK_THRESHOLD,
    min_hits: int = MIN_HITS
) -> str:
    logger.info(f"🔍 컨텍스트 선택 시작 - vec≥{vector_threshold}, cos≥{rerank_threshold}, min_hits={min_hits}")

    processed_query = preprocess_question(user_message)
    if processed_query in query_embedding_cache:
        q_emb = query_embedding_cache[processed_query]
    else:
        q_emb = model.encode([processed_query])[0]
        query_embedding_cache[processed_query] = q_emb

    # 1차 벡터 검색
    vector_results = collection.query(query_embeddings=[q_emb.tolist()], n_results=INITIAL_TOP_K)
    vector_distances = vector_results.get("distances", [[]])[0] if vector_results else []
    vector_similarities = [1 - dist for dist in vector_distances] if vector_distances else []

    # 2차 코사인 rerank
    retrieved_context, rerank_scores = retrieve_relevant_context_with_scores(user_message)

    # 🔸 최고점 하나만 높아도 통과(짧은 질의/정답 대응)
    if rerank_scores and max(rerank_scores) >= 0.65:
        logger.info("✅ 최고점 기준 충족(≥0.65) - 선별 컨텍스트 사용")
        return retrieved_context[:MAX_CTX_CHARS]

    # 🔸 임계 미달이어도 항상 "상위 검색 결과"만 사용(무관한 덩어리 금지)
    ctx_topk, scores_topk = retrieve_relevant_context_with_scores(
        user_message, initial_top_k=INITIAL_TOP_K, final_top_k=TOP_K_ALWAYS
    )
    if scores_topk:
        logger.warning("🔁 기준 미달 - 상위 결과 기반 축소 컨텍스트 사용")
        return ctx_topk[:MAX_CTX_CHARS]

    # 🔸 정말 아무 것도 없으면 빈 컨텍스트
    logger.warning("🔁 기준 미달 & 후보 없음 - 빈 컨텍스트 반환")
    return ""

def retrieve_relevant_context(query: str, initial_top_k: int = INITIAL_TOP_K, final_top_k: int = FINAL_TOP_K) -> str:
    context, _ = retrieve_relevant_context_with_scores(query, initial_top_k, final_top_k)
    return context

# ==========================
# 응답 가드레일(컨텍스트 외 엔티티 차단)
# ==========================
def contains_only_from_context(answer: str, context: str) -> bool:
    """
    URL/전화번호 등 '컨텍스트 밖' 엔티티를 생성하면 False.
    필요한 패턴은 확장 가능.
    """
    url_pat = r'https?://[^\s)]+'
    phone_pat = r'\b(?:0\d{1,2}-\d{3,4}-\d{4})\b'

    for pat in [url_pat, phone_pat]:
        for m in re.findall(pat, answer):
            if m not in context:
                return False
    return True

# ==========================
# 라우팅
# ==========================
@app.route("/")
def index():
    global history
    if request.args.get("reset") == "true":
        history.clear()
    return render_template("chat.html", history=history)

@app.route("/reload", methods=["POST"])
def reload_kb():
    """
    지식파일을 다시 적재(운영 중 수동 갱신).
    """
    try:
        load_knowledge_from_folder()
        # 캐시된 질의 임베딩 초기화(선택)
        query_embedding_cache.clear()
        return jsonify({"status": "reloaded", "chunks": len(knowledge_chunks)})
    except Exception as e:
        logger.exception("지식 재적재 실패")
        return jsonify({"status": "error", "message": str(e)}), 500

@app.route("/send", methods=["POST"])
def send():
    global history
    user_message = request.get_json().get("message", "")

    history.append({"role": "user", "content": user_message})

    # 대화 히스토리 분석
    search_query = user_message
    recent_history = history[-6:] if len(history) > 6 else history
    context_hints = [m["content"] for m in recent_history if m["role"] == "user"]

    reference_keywords = ["전에", "아까", "방금", "이전", "앞서", "먼저", "무엇을", "뭘", "언제", "어떤"]
    is_referencing_previous = any(kw in user_message for kw in reference_keywords)

    if is_referencing_previous and len(context_hints) > 1:
        search_query = " ".join(context_hints[-3:])
        logger.info(f"🔗 이전 대화 참조 감지 - 확장 검색 쿼리 사용: {search_query[:100]}...")

    # 컨텍스트 검색
    context = get_best_context(
        search_query,
        vector_threshold=VECTOR_THRESHOLD,
        rerank_threshold=RERANK_THRESHOLD,
        min_hits=MIN_HITS
    )

    # 컨텍스트가 완전히 비어 있으면 LLM 호출 없이 즉시 종료
    if not context.strip():
        history.append({"role": "assistant", "content": "해당 질문에 대한 정보를 찾을 수 없습니다."})
        return jsonify({"reply": "해당 질문에 대한 정보를 찾을 수 없습니다."})

    # 디버깅용 저장
    try:
        with open("/tmp/load_data.txt", "w", encoding="utf-8") as f:
            f.write(context)
    except Exception:
        pass

    # 자연스럽게 다듬되, 사실/수치/고유명사는 그대로 유지하도록 지시
    system_prompt = f"""당신은 친절하고 정확한 한국어 어시스턴트입니다.
아래 Context(지식파일에서 나온 원문 문장들)만을 근거로, 사용자의 질문에
'자연스러운 서술형 문장'으로 답하세요.

[출력 원칙]
- 문장은 매끄럽게 다듬되, **사실/수치/고유명사**는 Context 그대로 유지하세요.
- **새로운 사실을 추측/추가/각색하지 마세요.** (Context에 없으면 쓰지 말기)
- 같은 의미의 문장이 중복되면 하나로 통합하세요.
- 불필요한 접두어, 출처/인용 표시는 쓰지 마세요.
- 두세 문장 이내로 간결하게 답변하세요. (필요할 때만 3~4문장)
- 질문 범위를 벗어난 내용은 포함하지 마세요.
- URL, 전화번호, 주소, 수치 등 고유명사는 반드시 Context에 '그대로' 존재할 때만 사용하세요. 존재하지 않으면 쓰지 마세요.
- Context에 답이 없으면: **"해당 질문에 대한 정보를 찾을 수 없습니다."** 라고만 답하세요.

[형식]
- 한국어 서술형 단락 1개(필요시 2개). 목록/표/코드는 사용하지 않습니다.

[Context 시작]
{context}
[Context 끝]"""

    messages = [{"role": "system", "content": system_prompt}]
    # 최근 대화 히스토리 추가 (최대 10턴)
    recent_messages = history[-10:] if len(history) > 10 else history
    for msg in recent_messages:
        messages.append({
            "role": msg["role"],
            "content": msg["content"]
        })

    payload = {
        "model": MODEL,
        "temperature": 0.2,
        "top_p": 0.9,
        "top_k": 40,
        "messages": messages,
        "stream": False,   # 단일 JSON 응답으로 받기
    }

    # Ollama 요청: 타임아웃 지정(연결 5초, 응답 최대 300초)
    try:
        resp = requests.post(OLLAMA_URL, json=payload, timeout=(5, 300))
        resp.raise_for_status()
        data = resp.json()  # stream=False 이므로 단일 JSON
        full_reply = data.get("message", {}).get("content", "") or ""
    except Exception as e:
        logger.exception("LLM 호출 실패")
        full_reply = f"모델 호출 중 오류가 발생했습니다: {e}"

    # 🔒 사후 검증: 컨텍스트 밖 엔티티 차단
    if not contains_only_from_context(full_reply, context):
        full_reply = "해당 질문에 대한 정보를 찾을 수 없습니다."

    history.append({"role": "assistant", "content": full_reply})

    # 히스토리 관리
    if len(history) > 24:
        history[:] = history[-24:]

    logger.info(f"💬 현재 대화 히스토리 길이: {len(history)}턴")
    return jsonify({"reply": full_reply})

@app.route("/reset", methods=["POST"])
def reset():
    history.clear()
    query_embedding_cache.clear()
    chunk_embedding_cache.clear()
    return jsonify({"status": "reset"})

# ==========================
# 앱 초기화
# ==========================
def init_app():
    load_knowledge_from_folder()
    return app

init_app()

if __name__ == "__main__":
    # 운영에선 80, 로컬 테스트시 8000 등
    app.run(host="0.0.0.0", port=80, debug=False, threaded=True)
