# -*- coding: utf-8 -*-
"""
🤖 本地模型管理模块 - 使用 llama-cpp-python 驱动本地大语言模型

支持功能：
- 模型下载（支持断点续传）
- 模型加载和卸载
- 文本生成和分类
- 与现有分类逻辑集成
"""

import os
import sys
import json
import threading
import time
from pathlib import Path
from typing import Optional, Callable, List, Dict, Any
from logger import logger

# Directory where downloaded model files are stored.
# NOTE: the directory (and its parents) is created as a side effect of
# importing this module.
MODEL_DIR = Path.home() / ".fileneatai" / "models"
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# JSON file in which LocalModel records the most recently loaded model settings
CONFIG_FILE = Path.home() / ".fileneatai" / "local_model_config.json"


# ═══════════════════════════════════════════════════════════
# Preset model catalogue (download URLs point at the hf-mirror.com mirror)
#
# Keys are the model ids accepted by ModelDownloader. Each entry holds:
#   name / description   - user-facing display strings
#   url / filename       - download source and the file name used in MODEL_DIR
#   size_mb              - approximate download size, used as a fallback total
#                          when the server does not report one
#   memory_required_mb   - approximate RAM needed to run the model
#   context_length       - context window size for the model
#   recommended          - whether the entry is flagged as the recommended one
# ═══════════════════════════════════════════════════════════
AVAILABLE_MODELS = {
    "qwen2.5-1.5b": {
        "name": "Qwen2.5-1.5B (推荐)",
        "description": "通义千问2.5 1.5B参数版，体积小速度快，适合文件分类",
        "url": "https://hf-mirror.com/Qwen/Qwen2.5-1.5B-Instruct-GGUF/resolve/main/qwen2.5-1.5b-instruct-q4_k_m.gguf",
        "filename": "qwen2.5-1.5b-instruct-q4_k_m.gguf",
        "size_mb": 1100,  # approx. 1.1 GB
        "memory_required_mb": 2048,  # needs roughly 2 GB of RAM
        "context_length": 2048,
        "recommended": True
    },
    "qwen2.5-3b": {
        "name": "Qwen2.5-3B",
        "description": "通义千问2.5 3B参数版，效果更好，需要更多内存",
        "url": "https://hf-mirror.com/Qwen/Qwen2.5-3B-Instruct-GGUF/resolve/main/qwen2.5-3b-instruct-q4_k_m.gguf",
        "filename": "qwen2.5-3b-instruct-q4_k_m.gguf",
        "size_mb": 2100,  # approx. 2.1 GB
        "memory_required_mb": 4096,  # needs roughly 4 GB of RAM
        "context_length": 2048,
        "recommended": False
    },
    "llama3.2-1b": {
        "name": "Llama3.2-1B",
        "description": "Meta Llama3.2 1B参数版，英文效果好",
        "url": "https://hf-mirror.com/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
        "filename": "Llama-3.2-1B-Instruct-Q4_K_M.gguf",
        "size_mb": 800,  # approx. 0.8 GB
        "memory_required_mb": 1536,  # needs roughly 1.5 GB of RAM
        "context_length": 2048,
        "recommended": False
    }
}


class ModelDownloader:
    """Download a preset model file on a background thread.

    An interrupted download is resumed with an HTTP ``Range`` request against
    the ``<filename>.downloading`` partial file in ``MODEL_DIR``; the partial
    file is renamed to its final name only once the transfer is complete.
    Progress and completion are reported through optional callbacks, which are
    invoked on the download thread (or on the caller's thread when ``start()``
    rejects an unknown model id).
    """

    def __init__(self, model_id: str,
                 progress_callback: Optional[Callable[[int, int, float], None]] = None,
                 complete_callback: Optional[Callable[[bool, str], None]] = None):
        """
        Create a downloader for one preset model.

        Args:
            model_id: Key into ``AVAILABLE_MODELS``.
            progress_callback: Called as ``(downloaded_bytes, total_bytes, speed)``
                at most every 0.5 s; ``speed`` is in MB/s.
            complete_callback: Called as ``(success, message)``; on success the
                message is the path of the model file, otherwise an error text.
        """
        self.model_id = model_id
        self.model_info = AVAILABLE_MODELS.get(model_id)
        self.progress_callback = progress_callback
        self.complete_callback = complete_callback
        self._stop_flag = False
        self._thread: Optional[threading.Thread] = None

    def start(self):
        """Start the download on a daemon thread.

        Reports failure through ``complete_callback`` if the model id is
        unknown. A call made while a download is still running is ignored.
        """
        if not self.model_info:
            self._notify_complete(False, f"未知模型: {self.model_id}")
            return

        if self._thread is not None and self._thread.is_alive():
            # A second worker would write to the same partial file and corrupt it.
            logger.warning("模型正在下载中，忽略重复的下载请求")
            return

        self._stop_flag = False
        self._thread = threading.Thread(target=self._download_thread, daemon=True)
        self._thread.start()

    def stop(self):
        """Request cancellation and wait up to 5 seconds for the worker to exit."""
        self._stop_flag = True
        worker = self._thread
        # Joining the current thread raises RuntimeError; that would happen if
        # stop() were called from inside one of the callbacks.
        if (worker is not None and worker.is_alive()
                and worker is not threading.current_thread()):
            worker.join(timeout=5)

    def _notify_complete(self, success: bool, message: str):
        """Invoke ``complete_callback`` if one was supplied."""
        if self.complete_callback:
            self.complete_callback(success, message)

    @staticmethod
    def _parse_content_range_total(content_range: str) -> int:
        """Return the total size from a ``Content-Range`` header value.

        Handles ``bytes 0-99/1000`` and ``bytes */1000``. Returns 0 when the
        header is missing, malformed, or reports an unknown total (``*``).
        """
        _, slash, total = (content_range or "").rpartition("/")
        if not slash:
            return 0
        try:
            return max(int(total.strip()), 0)
        except ValueError:
            return 0

    def _download_thread(self):
        """Worker body: fetch the model file and report the outcome."""
        response = None
        try:
            # Imported lazily, and inside the try block so that a missing
            # dependency is reported through complete_callback instead of
            # silently killing the thread.
            import requests

            url = self.model_info["url"]
            filename = self.model_info["filename"]
            filepath = MODEL_DIR / filename
            temp_filepath = MODEL_DIR / f"{filename}.downloading"

            # Nothing to do when the final file is already present.
            if filepath.exists():
                logger.info(f"模型已存在: {filepath}")
                self._notify_complete(True, str(filepath))
                return

            # Resume support: continue after the bytes already on disk.
            downloaded_size = 0
            if temp_filepath.exists():
                downloaded_size = temp_filepath.stat().st_size
                logger.info(f"断点续传：已下载 {downloaded_size / 1024 / 1024:.1f} MB")

            headers = {}
            if downloaded_size > 0:
                headers["Range"] = f"bytes={downloaded_size}-"

            logger.info(f"开始下载模型: {url}")
            response = requests.get(url, headers=headers, stream=True, timeout=30)

            status = response.status_code
            transfer_needed = True
            if status == 206:  # partial content: the server honoured the Range header
                total_size = self._parse_content_range_total(
                    response.headers.get("Content-Range", ""))
            elif status == 200:  # full body: the server ignored Range, start over
                total_size = int(response.headers.get("Content-Length", 0))
                downloaded_size = 0
            elif (status == 416 and downloaded_size > 0
                  and self._parse_content_range_total(
                      response.headers.get("Content-Range", "")) == downloaded_size):
                # The requested range starts at the end of the file, so the
                # partial file is already complete (for example the process
                # stopped just before the final rename). Finalise it.
                logger.info("临时文件已完整，跳过下载")
                total_size = downloaded_size
                transfer_needed = False
            else:
                raise Exception(f"下载失败: HTTP {status}")

            # Fall back to the catalogue estimate when the server gave no size;
            # an estimated size must not be used for the completeness check.
            size_known = total_size > 0
            if not size_known:
                total_size = self.model_info["size_mb"] * 1024 * 1024

            if transfer_needed:
                mode = "ab" if downloaded_size > 0 else "wb"
                last_report_time = time.time()
                last_downloaded = downloaded_size

                with open(temp_filepath, mode) as f:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
                        if self._stop_flag:
                            logger.info("下载已取消")
                            self._notify_complete(False, "下载已取消")
                            return

                        if not chunk:
                            continue

                        f.write(chunk)
                        downloaded_size += len(chunk)

                        # Throttle progress reports to one every 0.5 s.
                        current_time = time.time()
                        elapsed = current_time - last_report_time
                        if elapsed >= 0.5:
                            speed = (downloaded_size - last_downloaded) / elapsed / 1024 / 1024  # MB/s
                            if self.progress_callback:
                                self.progress_callback(downloaded_size, total_size, speed)
                            last_report_time = current_time
                            last_downloaded = downloaded_size

                # A stream that ended early must not be promoted to a model
                # file; the partial file is kept so the next attempt resumes.
                if size_known and downloaded_size < total_size:
                    raise Exception(
                        f"文件不完整 ({downloaded_size}/{total_size} 字节)，请重试以继续下载")

            # Transfer complete: promote the partial file to its final name.
            temp_filepath.rename(filepath)
            logger.info(f"模型下载完成: {filepath}")
            self._notify_complete(True, str(filepath))

        except Exception as e:
            logger.error(f"下载模型失败: {str(e)}")
            self._notify_complete(False, f"下载失败: {str(e)}")
        finally:
            # A streamed response holds its connection until it is closed.
            if response is not None:
                response.close()


class LocalModel:
    """Process-wide holder for a local LLM driven by llama-cpp-python.

    Use ``LocalModel.get_instance()`` to obtain the shared instance. At most
    one model is held at a time: ``load()`` unloads any previous model first.
    """

    # Shared singleton instance and the lock guarding its creation.
    _instance: Optional["LocalModel"] = None
    _lock = threading.Lock()

    # System prompt used by generate() when the caller does not supply one.
    DEFAULT_SYSTEM_PROMPT = "你是一个文件分类助手，请根据文件内容和文件名判断它应该归类到哪个文件夹。只回答文件夹名称，不要解释。"

    def __init__(self):
        self.model_path: Optional[str] = None  # path of the loaded model file
        self.llm = None                        # llama_cpp.Llama instance, or None
        self._loaded = False
        self._loading = False

    @classmethod
    def get_instance(cls) -> "LocalModel":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock so only one thread creates the instance.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    def is_loaded(self) -> bool:
        """Return True when a model is loaded and ready for inference."""
        return self._loaded and self.llm is not None

    def is_loading(self) -> bool:
        """Return True while ``load()`` is in progress."""
        return self._loading

    def get_model_path(self) -> Optional[str]:
        """Return the path of the loaded model, or None when nothing is loaded."""
        return self.model_path if self._loaded else None

    def load(self, model_path: str, n_ctx: int = 2048, n_gpu_layers: int = 0,
             progress_callback: Optional[Callable[[str], None]] = None,
             n_threads: int = 4) -> bool:
        """
        Load a model file, replacing any model that is currently loaded.

        Args:
            model_path: Path to the model file.
            n_ctx: Context length passed to llama-cpp-python.
            n_gpu_layers: Number of layers to offload to the GPU
                (0 = CPU only, -1 = all layers).
            progress_callback: Receives human-readable status messages.
            n_threads: Number of CPU threads used for inference.

        Returns:
            True on success. False if another load is already running,
            llama-cpp-python is not installed, or loading failed.
        """
        if self._loading:
            logger.warning("模型正在加载中...")
            return False

        try:
            self._loading = True

            # Release the previous model before allocating the new one.
            if self._loaded:
                self.unload()

            if progress_callback:
                progress_callback("正在初始化 llama-cpp-python...")

            # Imported lazily so the module still imports without the dependency.
            try:
                from llama_cpp import Llama
            except ImportError as e:
                logger.error(f"llama-cpp-python 未安装: {str(e)}")
                if progress_callback:
                    progress_callback("错误: llama-cpp-python 未安装，请先安装依赖")
                return False

            if progress_callback:
                progress_callback(f"正在加载模型: {os.path.basename(model_path)}...")

            logger.info(f"开始加载本地模型: {model_path}")
            logger.info(f"配置: n_ctx={n_ctx}, n_gpu_layers={n_gpu_layers}")

            self.llm = Llama(
                model_path=model_path,
                n_ctx=n_ctx,
                n_gpu_layers=n_gpu_layers,
                verbose=False,  # keep llama.cpp's own logging quiet
                n_threads=n_threads,
            )

            self.model_path = model_path
            self._loaded = True

            if progress_callback:
                progress_callback("模型加载完成！")

            logger.info(f"本地模型加载成功: {model_path}")

            # Record the settings of the model that was just loaded.
            self._save_config(model_path, n_ctx, n_gpu_layers)

            return True

        except Exception as e:
            logger.error(f"加载本地模型失败: {str(e)}")
            if progress_callback:
                progress_callback(f"加载失败: {str(e)}")
            self._loaded = False
            self.llm = None
            return False
        finally:
            self._loading = False

    def unload(self):
        """Unload the current model and release its resources."""
        llm = self.llm
        self.llm = None
        self._loaded = False
        self.model_path = None

        if llm is not None:
            logger.info("正在卸载本地模型...")
            # Newer llama-cpp-python releases provide close() to free native
            # memory immediately; older ones free it when the object is
            # garbage-collected.
            close = getattr(llm, "close", None)
            if callable(close):
                try:
                    close()
                except Exception as e:
                    logger.warning(f"释放本地模型资源时出错: {str(e)}")
            del llm
        logger.info("本地模型已卸载")

    def generate(self, prompt: str, max_tokens: int = 200, temperature: float = 0.1,
                 system_prompt: Optional[str] = None) -> str:
        """
        Generate a reply to ``prompt`` through the chat-completion API.

        Args:
            prompt: User message.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            system_prompt: System message; defaults to the file-classification
                prompt in ``DEFAULT_SYSTEM_PROMPT``.

        Returns:
            The generated text with surrounding whitespace removed.

        Raises:
            RuntimeError: If no model is loaded.
            Exception: Any error raised by llama-cpp-python is logged and re-raised.
        """
        if not self._loaded or self.llm is None:
            raise RuntimeError("本地模型未加载")

        if system_prompt is None:
            system_prompt = self.DEFAULT_SYSTEM_PROMPT

        try:
            response = self.llm.create_chat_completion(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                stop=["</s>", "\n\n"]  # stop sequences
            )

            content = response["choices"][0]["message"]["content"]
            # Guard against a missing/None content field.
            return (content or "").strip()

        except Exception as e:
            logger.error(f"本地模型生成失败: {str(e)}")
            raise

    @staticmethod
    def _match_folder(result: str, folders: List[str]) -> Optional[str]:
        """Map a model answer onto one of ``folders``, or return None.

        Matching is case-insensitive and tried in this order:
          1. exact match;
          2. a folder name contained in the answer - the longest such name
             wins, so "Documents" is preferred over "Doc";
          3. the first folder whose name contains the answer.
        An empty answer never matches.
        """
        answer = result.strip().lower()
        if not answer:
            # "" is a substring of every string; do not treat it as a match.
            return None

        for folder in folders:
            if folder.lower() == answer:
                return folder

        contained = [f for f in folders if f and f.lower() in answer]
        if contained:
            return max(contained, key=len)

        for folder in folders:
            if answer in folder.lower():
                return folder

        return None

    def classify_file(self, file_content: str, file_name: str, folders: List[str]) -> str:
        """
        Ask the model which folder a file belongs to.

        Args:
            file_content: File text; only the first 300 characters are sent.
            file_name: Name of the file.
            folders: Candidate folder names; only the first 10 are shown to
                the model, but the answer is matched against all of them.

        Returns:
            The chosen folder name. Falls back to ``folders[0]`` when the
            answer cannot be matched or generation fails, and returns ""
            when ``folders`` is empty.
        """
        if not folders:
            return ""

        # Keep the prompt short: at most 10 candidate folders.
        folder_list = "\n".join([f"- {f}" for f in folders[:10]])

        prompt = f"""根据以下文件信息，从给定的文件夹中选择最合适的一个：

文件名: {file_name}
内容摘要: {file_content[:300] if file_content else '(无内容)'}

可选文件夹:
{folder_list}

请直接回答文件夹名称（只回答名称，不要其他内容）:"""

        try:
            result = self.generate(prompt, max_tokens=50, temperature=0.1)

            # Strip whitespace and surrounding quotes from the answer.
            result = result.strip().strip('"').strip("'")

            matched = self._match_folder(result, folders)
            if matched is not None:
                return matched

            # No usable answer: fall back to the first folder.
            logger.warning(f"本地模型返回的文件夹 '{result}' 不在列表中，使用第一个文件夹")
            return folders[0]

        except Exception as e:
            logger.error(f"本地模型分类失败: {str(e)}")
            return folders[0] if folders else ""

    def _save_config(self, model_path: str, n_ctx: int, n_gpu_layers: int):
        """Write the given model settings to ``CONFIG_FILE`` (best effort).

        Failures are logged as warnings and never raised.
        """
        try:
            config = {
                "model_path": model_path,
                "n_ctx": n_ctx,
                "n_gpu_layers": n_gpu_layers,
                "last_used": time.strftime("%Y-%m-%d %H:%M:%S")
            }
            with open(CONFIG_FILE, "w", encoding="utf-8") as f:
                json.dump(config, f, ensure_ascii=False, indent=2)
        except Exception as e:
            logger.warning(f"保存本地模型配置失败: {str(e)}")

    @staticmethod
    def load_config() -> Optional[Dict[str, Any]]:
        """Read the saved settings from ``CONFIG_FILE``.

        Returns:
            The parsed JSON content, or None when the file does not exist or
            cannot be read or parsed.
        """
        try:
            if CONFIG_FILE.exists():
                with open(CONFIG_FILE, "r", encoding="utf-8") as f:
                    return json.load(f)
        except Exception as e:
            logger.warning(f"加载本地模型配置失败: {str(e)}")
        return None


# ═══════════════════════════════════════════════════════════
# 工具函数
# ═══════════════════════════════════════════════════════════

def get_downloaded_models() -> List[Dict[str, Any]]:
    """Return one record per preset model whose file exists in ``MODEL_DIR``.

    Each record has the keys ``id``, ``name``, ``description``, ``path`` and
    ``size_mb`` (the actual size of the file on disk, in MB).
    """
    found: List[Dict[str, Any]] = []

    for key, meta in AVAILABLE_MODELS.items():
        model_file = MODEL_DIR / meta["filename"]
        if not model_file.exists():
            continue

        record = {
            "id": key,
            "name": meta["name"],
            "description": meta["description"],
            "path": str(model_file),
            "size_mb": model_file.stat().st_size / 1024 / 1024,
        }
        found.append(record)

    return found


def get_available_models() -> List[Dict[str, Any]]:
    """Return one record per preset model, downloaded or not.

    Each record carries the catalogue data plus ``downloaded`` (whether the
    file exists in ``MODEL_DIR``) and ``path`` (the file path, or None when
    the file is not present).
    """
    located = (
        (key, meta, MODEL_DIR / meta["filename"])
        for key, meta in AVAILABLE_MODELS.items()
    )
    return [
        {
            "id": key,
            "name": meta["name"],
            "description": meta["description"],
            "size_mb": meta["size_mb"],
            "memory_required_mb": meta["memory_required_mb"],
            "downloaded": target.exists(),
            "path": str(target) if target.exists() else None,
            "recommended": meta.get("recommended", False),
        }
        for key, meta, target in located
    ]


def is_llama_cpp_available() -> bool:
    """Return True when ``llama_cpp.Llama`` can be imported, False otherwise."""
    try:
        from llama_cpp import Llama  # noqa: F401 - imported only to probe availability
    except ImportError:
        return False
    else:
        return True


def get_local_model() -> LocalModel:
    """Return the shared ``LocalModel`` instance."""
    singleton = LocalModel.get_instance()
    return singleton


def classify_with_local_model(file_content: str, file_name: str, folders: List[str]) -> str:
    """Classify a file with the shared local model (convenience wrapper).

    Raises:
        RuntimeError: If the local model has not been loaded.
    """
    local_model = get_local_model()
    if local_model.is_loaded():
        return local_model.classify_file(file_content, file_name, folders)
    raise RuntimeError("本地模型未加载")


# ═══════════════════════════════════════════════════════════
# LangChain 兼容包装器
# ═══════════════════════════════════════════════════════════

class LocalModelLLMWrapper:
    """
    Adapter that exposes the shared local model through an ``invoke()`` method.

    Intended as a drop-in for call sites that use ``llm.invoke(...)``. Unlike
    a chat-model object, ``invoke()`` here returns the reply as a plain string.
    """

    def __init__(self, temperature: float = 0.1):
        self.temperature = temperature
        self._model = get_local_model()

    def invoke(self, input_text: str, **kwargs) -> str:
        """
        Run inference on the local model.

        Args:
            input_text: A prompt string, or a list of messages (objects with a
                ``content`` attribute, dicts with a ``content`` key, or
                ``(role, content)`` pairs). Anything else is converted with
                ``str()``.
            **kwargs: ``max_tokens`` overrides the default limit of 200; other
                keyword arguments are accepted and ignored.

        Returns:
            The model's reply as a string.

        Raises:
            RuntimeError: If the local model has not been loaded.
        """
        if not self._model.is_loaded():
            raise RuntimeError("内置本地模型未加载，请先在设置中加载模型")

        if isinstance(input_text, str):
            prompt = input_text
        elif isinstance(input_text, list):
            prompt = self._format_messages(input_text)
        else:
            prompt = str(input_text)

        # Fall back to the default when max_tokens is absent or None.
        max_tokens = kwargs.get("max_tokens") or 200

        return self._model.generate(
            prompt,
            max_tokens=max_tokens,
            temperature=self.temperature
        )

    @staticmethod
    def _content_to_text(content) -> str:
        """Coerce a message's content to text.

        None becomes "", strings pass through, and a list of content parts is
        concatenated (dict parts contribute their ``"text"`` value, if any).
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            pieces = []
            for part in content:
                if isinstance(part, dict):
                    part = part.get("text", "")
                pieces.append("" if part is None else str(part))
            return "".join(pieces)
        return str(content)

    def _format_messages(self, messages: list) -> str:
        """Flatten a list of messages into one newline-separated string."""
        formatted = []
        for msg in messages:
            if hasattr(msg, 'content'):
                content = msg.content
            elif isinstance(msg, dict):
                content = msg.get('content', '')
            elif isinstance(msg, (tuple, list)) and len(msg) == 2:
                # (role, content) pair
                content = msg[1]
            else:
                content = str(msg)
            # Coerce to str so that join() cannot fail on None or list content.
            formatted.append(self._content_to_text(content))
        return "\n".join(formatted)

    def __call__(self, input_text, **kwargs):
        """Make the wrapper callable; equivalent to ``invoke()``."""
        return self.invoke(input_text, **kwargs)


def create_local_llm(temperature: float = 0.1) -> LocalModelLLMWrapper:
    """
    Build an ``invoke()``-style wrapper around the shared local model.

    Example:
        llm = create_local_llm()
        response = llm.invoke("你好")

    Raises:
        RuntimeError: If the local model has not been loaded.
    """
    if get_local_model().is_loaded():
        return LocalModelLLMWrapper(temperature=temperature)
    raise RuntimeError("内置本地模型未加载，请先在设置中加载模型")

