# -*- coding: utf-8 -*-
"""
异步文件扫描器 - 解决大量文件处理时的UI卡顿问题
"""

import os
import threading
import time
from pathlib import Path
from typing import List, Dict, Callable, Optional
from PySide6.QtCore import QObject, Signal, QThread, QTimer
from PySide6.QtWidgets import QApplication
from logger import logger
from lib.common import Common


class FileInfo:
    """Metadata snapshot of one file, captured once when the object is built.

    Attributes:
        file_path: the path exactly as passed in.
        file_name: the final path component (``os.path.basename``).
        file_size: size in bytes from ``os.stat``; stays 0 if it could not be read.
        file_type: lower-cased suffix including the leading dot, or "" if none.
        is_valid: True only when the path exists and ``Common.is_temp_file``
            reports it is not a temporary file.
        error_msg: why the file is invalid ("" when valid).
    """

    def __init__(self, file_path: str):
        self.file_path = file_path
        self.file_name = os.path.basename(file_path)
        # Start from the "unusable file" state; the probe below upgrades it.
        self.file_size = 0
        self.file_type = Path(file_path).suffix.lower()
        self.is_valid = False
        self.error_msg = ""

        try:
            if not os.path.exists(file_path):
                self.error_msg = "文件不存在"
                return

            self.file_size = os.stat(file_path).st_size
            is_temp = Common.is_temp_file(file_path)
            self.is_valid = not is_temp
            if is_temp:
                self.error_msg = "临时文件"
        except Exception as exc:
            # Any failure while probing (stat, temp-file check) marks the file unreadable.
            self.error_msg = f"读取失败: {str(exc)}"


class AsyncFileScanner(QObject):
    """Scan one or more directory trees on a background thread.

    Progress and results are reported through Qt signals. A file is accepted
    when ``FileInfo.is_valid`` is true, its extension is in
    ``supported_extensions``, its size is at most ``max_file_size`` and its
    name does not start with '.' or '~'. Sub-directories whose *name* starts
    with '.' or contains 'output' are not descended into; the scan root itself
    is always scanned, whatever its absolute path looks like.
    """

    # Signal definitions
    scan_started = Signal(str)  # a root directory starts being scanned; arg: root path
    scan_progress = Signal(int, int, str)  # args: accepted so far, estimated total (all roots so far), current file name
    scan_completed = Signal(list)  # arg: every accepted FileInfo
    scan_error = Signal(str)  # arg: error message
    batch_ready = Signal(list)  # arg: a batch (up to batch_size) of accepted FileInfo
    scan_stats = Signal(dict)  # arg: filtering statistics, emitted once just before scan_completed

    def __init__(self, parent=None):
        super().__init__(parent)
        self.is_scanning = False
        self.should_stop = False  # polled by the worker; set by stop_scan()
        self.scan_thread = None

        # Configuration
        self.batch_size = 500  # accepted files per batch_ready emission
        self.max_file_size = 100 * 1024 * 1024  # 100 MB size limit
        self.quick_count_cap = 10000  # upper bound for the per-root pre-count

        # Shared definition of supported file extensions
        from lib.constant import supported_extensions
        self.supported_extensions = supported_extensions

    def start_scan(self, root_paths: List[str], max_files: Optional[int] = None):
        """Start scanning ``root_paths`` on a daemon thread.

        Args:
            root_paths: directories to scan, in order.
            max_files: maximum number of files to accept; None or 0 means unlimited.

        Returns:
            True if a scan was started, False if one is already running.
        """
        if self.is_scanning:
            logger.warning("扫描已在进行中")
            return False

        self.is_scanning = True
        self.should_stop = False

        self.scan_thread = threading.Thread(
            target=self._scan_worker,
            args=(root_paths, max_files),
            daemon=True,
        )
        self.scan_thread.start()
        return True

    def stop_scan(self):
        """Ask the worker to stop and wait up to 2 seconds for it to exit.

        NOTE(review): ``is_scanning`` is cleared even if the worker is still
        alive after the timeout, so a subsequent start_scan could overlap it.
        """
        self.should_stop = True
        if self.scan_thread and self.scan_thread.is_alive():
            self.scan_thread.join(timeout=2.0)
        self.is_scanning = False

    @staticmethod
    def _prune_dirs(dirs: List[str]) -> None:
        """Drop hidden ('.'-prefixed) and output directories from an ``os.walk`` dirs list in place.

        Matching on the directory *name* (instead of testing the full path for
        'output') means a scan root that merely lives under a path containing
        'output' is still scanned, and skipped trees are not walked at all.
        """
        dirs[:] = [d for d in dirs if not d.startswith('.') and 'output' not in d]

    def _record_skip_reason(self, file_info: FileInfo, stats: dict) -> bool:
        """Return True if ``file_info`` fails a counted filter, updating ``stats``.

        Counted filters: invalid file (only temp files are tallied), unsupported
        extension, and oversize file.
        """
        if not file_info.is_valid:
            # Missing/unreadable files carry other error messages and are not tallied.
            if not file_info.error_msg or '临时文件' in file_info.error_msg:
                stats['skipped_temp_files'] += 1
            return True

        if file_info.file_type not in self.supported_extensions:
            stats['skipped_unsupported'] += 1
            return True

        if file_info.file_size > self.max_file_size:
            stats['skipped_large_files'].append({
                'name': file_info.file_name,
                'size': file_info.file_size
            })
            return True

        return False

    def _scan_worker(self, root_paths: List[str], max_files: Optional[int]):
        """Worker-thread body: scan every root, emit batches, stats and the final result.

        ``stats`` keys: skipped_large_files (list of {'name', 'size'}),
        skipped_unsupported, skipped_temp_files, total_scanned,
        skipped_due_to_limit, max_files.
        """
        try:
            all_files = []
            total_scanned = 0
            # Valid files seen, including those dropped only because of max_files.
            total_valid_files = 0
            # Sum of per-root estimates so far; all_files is cumulative across roots too.
            estimated_total = 0

            stats = {
                'skipped_large_files': [],
                'skipped_unsupported': 0,
                'skipped_temp_files': 0,
                'total_scanned': 0,
                'skipped_due_to_limit': 0,
                'max_files': max_files
            }

            for root_path in root_paths:
                if self.should_stop:
                    break

                self.scan_started.emit(root_path)
                logger.info(f"开始扫描目录: {root_path}")

                root_estimate = self._quick_count_files(root_path)
                estimated_total += root_estimate
                logger.info(f"预估文件数量: {root_estimate}")

                current_batch = []

                for file_info in self._scan_directory_recursive(root_path):
                    if self.should_stop:
                        break

                    total_scanned += 1
                    stats['total_scanned'] = total_scanned

                    if max_files and len(all_files) >= max_files:
                        # Limit reached: keep walking only to count how many valid
                        # files the limit caused us to drop.
                        if self._is_file_valid_for_count(file_info):
                            total_valid_files += 1
                    elif not self._record_skip_reason(file_info, stats) and self._is_file_valid(file_info):
                        total_valid_files += 1
                        current_batch.append(file_info)
                        all_files.append(file_info)

                        self.scan_progress.emit(len(all_files), estimated_total, file_info.file_name)

                        if len(current_batch) >= self.batch_size:
                            self.batch_ready.emit(current_batch.copy())
                            current_batch.clear()

                    # Yield the CPU briefly every 100 entries, including while only counting.
                    if total_scanned % 100 == 0:
                        time.sleep(0.01)

                # Flush the partial last batch of this root.
                if current_batch:
                    self.batch_ready.emit(current_batch)

            if max_files and total_valid_files > max_files:
                stats['skipped_due_to_limit'] = total_valid_files - max_files
                logger.warning(f"因数量限制跳过了 {stats['skipped_due_to_limit']} 个有效文件")

            self.scan_stats.emit(stats)

            self.scan_completed.emit(all_files)
            logger.info(f"扫描完成，共找到 {len(all_files)} 个有效文件，"
                        f"跳过大文件 {len(stats['skipped_large_files'])} 个，"
                        f"不支持格式 {stats['skipped_unsupported']} 个，"
                        f"因数量限制跳过 {stats.get('skipped_due_to_limit', 0)} 个")

        except Exception as e:
            logger.error(f"扫描过程出错: {str(e)}")
            self.scan_error.emit(str(e))
        finally:
            self.is_scanning = False

    def _quick_count_files(self, root_path: str) -> int:
        """Roughly count files under ``root_path`` (all extensions), capped at ``quick_count_cap``.

        Uses the same directory pruning as the real scan. Returns 1000 as a
        fallback estimate if the walk fails.
        """
        try:
            count = 0
            for _root, dirs, files in os.walk(root_path):
                if self.should_stop:
                    break

                self._prune_dirs(dirs)
                count += len(files)

                # Stop early on very large trees; this is only an estimate.
                if count > self.quick_count_cap:
                    break

            return min(count, self.quick_count_cap)
        except Exception as e:
            logger.warning(f"快速统计文件数量失败: {str(e)}")
            return 1000

    def _scan_directory_recursive(self, root_path: str):
        """Yield a FileInfo for every file under ``root_path``, honouring directory pruning.

        Emits ``scan_error`` (and yields nothing) if the root is missing or not readable.
        """
        try:
            if not os.path.exists(root_path):
                logger.error(f"目录不存在: {root_path}")
                self.scan_error.emit(f"目录不存在: {root_path}")
                return

            if not os.access(root_path, os.R_OK):
                logger.error(f"无权限访问目录: {root_path}")
                self.scan_error.emit("无权限访问该文件夹，请检查权限设置")
                return

            for root, dirs, files in os.walk(root_path):
                if self.should_stop:
                    break

                self._prune_dirs(dirs)

                for file_name in files:
                    if self.should_stop:
                        break
                    yield FileInfo(os.path.join(root, file_name))

        except PermissionError as e:
            logger.error(f"权限错误: {root_path}: {str(e)}")
            self.scan_error.emit("无法访问文件夹，权限不足")
        except Exception as e:
            logger.error(f"递归扫描目录失败 {root_path}: {str(e)}")
            self.scan_error.emit(f"扫描失败: {str(e)}")

    def _is_file_valid(self, file_info: FileInfo, log_oversize: bool = True) -> bool:
        """Return True if ``file_info`` should be accepted.

        Args:
            file_info: the candidate file.
            log_oversize: log a warning when the file is rejected for size.
        """
        if not file_info.is_valid:
            return False

        if file_info.file_type not in self.supported_extensions:
            return False

        if file_info.file_size > self.max_file_size:
            if log_oversize:
                logger.warning(f"文件过大，跳过: {file_info.file_name} ({file_info.file_size / 1024 / 1024:.1f}MB)")
            return False

        # Names starting with '.' or '~' (hidden / temp-style files) are never accepted.
        if file_info.file_name.startswith(('.', '~')):
            return False

        return True

    def _is_file_valid_for_count(self, file_info: FileInfo) -> bool:
        """Same rules as ``_is_file_valid`` but silent; used to count files dropped by the limit."""
        return self._is_file_valid(file_info, log_oversize=False)


class SmartFileLoader(QObject):
    """Feed file entries to the UI in throttled batches.

    A worker thread converts FileInfo objects into plain dicts and queues them.
    ``ui_update_timer`` (owned by this object's thread) drains the queue every
    100 ms and emits ``ui_update_ready``, so receivers update widgets in
    batches instead of once per file.
    """

    # Signal definitions
    loading_started = Signal(int)  # arg: total number of files
    loading_progress = Signal(int, str)  # args: number prepared so far, current file name
    loading_completed = Signal()  # emitted after the final ui_update_ready of a successful load
    ui_update_ready = Signal(list)  # arg: list of UI item dicts
    # Internal: emitted by the worker thread when it exits (True on success).
    _worker_finished = Signal(bool)

    def __init__(self, parent=None):
        super().__init__(parent)
        from PySide6.QtCore import Qt

        # Parented so the timer shares this object's lifetime and thread affinity.
        self.ui_update_timer = QTimer(self)
        self.ui_update_timer.timeout.connect(self._flush_ui_updates)
        self.ui_update_timer.setSingleShot(False)
        self.pending_ui_updates = []
        self.update_lock = threading.Lock()
        # A QTimer can only be stopped from the thread that owns it, so the
        # worker reports completion through a queued signal instead of touching
        # the timer directly.
        self._worker_finished.connect(
            self._on_worker_finished, type=Qt.ConnectionType.QueuedConnection
        )

    def start_loading(self, file_list: List[FileInfo], tree_widget):
        """Start pushing ``file_list`` to the UI; call from this object's thread.

        ``tree_widget`` is stored on the instance but not used by this class.
        """
        self.tree_widget = tree_widget
        self.file_list = file_list
        self.loading_started.emit(len(file_list))

        # Flush queued UI items every 100 ms.
        self.ui_update_timer.start(100)

        loading_thread = threading.Thread(
            target=self._loading_worker,
            daemon=True
        )
        loading_thread.start()

    @staticmethod
    def _field(item, key: str, default):
        """Read ``key`` from a dict, or the attribute ``key`` from any other object."""
        if isinstance(item, dict):
            return item.get(key, default)
        return getattr(item, key, default)

    def add_files_async(self, file_batch: List[Dict]):
        """Queue a batch of files for the next UI flush.

        Items may be dicts (keys 'file_path', 'file_name', 'file_size',
        'file_type', optional 'status') or objects exposing those names as
        attributes, e.g. the FileInfo batches emitted by
        ``AsyncFileScanner.batch_ready``. Missing fields get defaults.
        NOTE(review): this may start ``ui_update_timer``, so call it from the
        thread that owns this object.
        """
        try:
            ui_items = [
                {
                    'file_path': self._field(item, 'file_path', ''),
                    'file_name': self._field(item, 'file_name', ''),
                    'file_size': self._field(item, 'file_size', 0),
                    'file_type': self._field(item, 'file_type', ''),
                    'status': self._field(item, 'status', '待处理')
                }
                for item in file_batch
            ]

            with self.update_lock:
                self.pending_ui_updates.extend(ui_items)

            if not self.ui_update_timer.isActive():
                self.ui_update_timer.start(100)

        except Exception as e:
            logger.error(f"异步添加文件批次失败: {str(e)}")

    def _loading_worker(self):
        """Worker thread: convert ``self.file_list`` to UI dicts in batches of 50.

        Never touches Qt objects other than emitting signals; completion is
        reported through ``_worker_finished``.
        """
        succeeded = False
        try:
            batch_size = 50

            for i in range(0, len(self.file_list), batch_size):
                batch = self.file_list[i:i + batch_size]
                ui_items = []

                for file_info in batch:
                    # Plain data only; widgets are built by the ui_update_ready receiver.
                    ui_items.append({
                        'file_path': file_info.file_path,
                        'file_name': file_info.file_name,
                        'file_size': file_info.file_size,
                        'file_type': file_info.file_type
                    })
                    self.loading_progress.emit(i + len(ui_items), file_info.file_name)

                with self.update_lock:
                    self.pending_ui_updates.extend(ui_items)

                # Pace the producer so the UI thread keeps up.
                time.sleep(0.02)

            succeeded = True

        except Exception as e:
            logger.error(f"文件加载工作线程出错: {str(e)}")
        finally:
            self._worker_finished.emit(succeeded)

    def _on_worker_finished(self, succeeded: bool):
        """Runs in this object's thread: stop the timer, flush leftovers, report completion."""
        self.ui_update_timer.stop()
        self._flush_ui_updates()
        if succeeded:
            self.loading_completed.emit()

    def _flush_ui_updates(self):
        """Emit all queued UI items as one ``ui_update_ready`` batch, if any."""
        with self.update_lock:
            if not self.pending_ui_updates:
                return
            batch = self.pending_ui_updates.copy()
            self.pending_ui_updates.clear()
        # Emit outside the lock: a directly connected slot that calls
        # add_files_async would otherwise deadlock on the non-reentrant lock.
        self.ui_update_ready.emit(batch)


class PerformanceMonitor(QObject):
    """Emit processing metrics once per second while monitoring is active.

    CPU and memory usage come from ``psutil`` when it is installed; otherwise
    both are reported as 0.
    """

    performance_update = Signal(dict)  # arg: metrics dict built by _collect_metrics

    def __init__(self, parent=None):
        super().__init__(parent)
        # Parented so the timer shares this object's lifetime and thread affinity.
        self.monitor_timer = QTimer(self)
        self.monitor_timer.timeout.connect(self._collect_metrics)
        self.start_time = time.time()
        self.file_count = 0
        self.error_count = 0
        # psutil is optional; resolve it once instead of retrying a failing
        # import on every timer tick.
        try:
            import psutil
        except ImportError:
            psutil = None
        self._psutil = psutil

    def start_monitoring(self):
        """Reset the elapsed-time origin and emit metrics every 1000 ms."""
        self.start_time = time.time()
        self.monitor_timer.start(1000)

    def stop_monitoring(self):
        """Stop emitting metrics."""
        self.monitor_timer.stop()

    def record_file_processed(self):
        """Count one processed file."""
        self.file_count += 1

    def record_error(self):
        """Count one processing error."""
        self.error_count += 1

    def _collect_metrics(self):
        """Build the metrics dict and emit it via ``performance_update``.

        Keys: cpu_percent, memory_percent, processing_rate (files per second,
        elapsed time floored at 1 s), error_rate (errors / processed files,
        file count floored at 1), total_files, total_errors, elapsed_time (s).
        """
        try:
            if self._psutil is not None:
                cpu_percent = self._psutil.cpu_percent()
                memory_percent = self._psutil.virtual_memory().percent
            else:
                cpu_percent = 0
                memory_percent = 0

            elapsed_time = time.time() - self.start_time

            metrics = {
                'cpu_percent': cpu_percent,
                'memory_percent': memory_percent,
                'processing_rate': self.file_count / max(elapsed_time, 1),
                'error_rate': self.error_count / max(self.file_count, 1),
                'total_files': self.file_count,
                'total_errors': self.error_count,
                'elapsed_time': elapsed_time
            }

            self.performance_update.emit(metrics)

        except Exception as e:
            logger.error(f"收集性能指标失败: {str(e)}")


class SmartThreadManager:
    """Suggest worker-thread counts from host resources and task type."""

    def __init__(self):
        # os.cpu_count() returns None when undeterminable; assume 4 in that case.
        self.cpu_count = os.cpu_count() or 4
        self.optimal_workers = self._calculate_optimal_workers()

    def _calculate_optimal_workers(self) -> int:
        """Return a base thread budget derived from core count and total RAM.

        Tiers: >=16 GB RAM and >=8 cores -> up to 12; >=8 GB and >=4 cores ->
        up to 8; otherwise up to 4, always capped by the core count. Without
        psutil, or if probing fails, falls back to min(4, cpu_count).
        """
        try:
            import psutil

            memory_gb = psutil.virtual_memory().total / (1024 ** 3)
            cpu_count = self.cpu_count

            if memory_gb >= 16 and cpu_count >= 8:
                base_workers = min(12, cpu_count)
            elif memory_gb >= 8 and cpu_count >= 4:
                base_workers = min(8, cpu_count)
            else:
                base_workers = min(4, cpu_count)

            logger.info(f"系统配置: {cpu_count}核心, {memory_gb:.1f}GB内存, 建议线程数: {base_workers}")
            return base_workers

        except ImportError:
            return min(4, self.cpu_count)
        except Exception as e:
            logger.error(f"计算最优线程数失败: {str(e)}")
            # Same conservative fallback as the no-psutil path; never exceed the core count.
            return min(4, self.cpu_count)

    def get_workers_for_task(self, task_type: str, file_count: int, file_complexity: str = "medium") -> int:
        """Return the recommended number of worker threads for a task (always >= 1).

        Args:
            task_type: "file_scan", "local_model", "online_model",
                "image_processing"; anything else gets the full base budget.
            file_count: currently unused; kept for API compatibility.
            file_complexity: for "local_model" only: "simple", "complex", or
                anything else (treated as medium).
        """
        base_workers = self.optimal_workers

        if task_type == "file_scan":
            # Scanning is I/O-light per file; a couple of threads suffice.
            workers = min(2, base_workers)
        elif task_type == "local_model":
            if file_complexity == "simple":
                workers = min(base_workers, 8)
            elif file_complexity == "complex":
                workers = min(base_workers // 2, 4)
            else:
                workers = min(base_workers, 6)
        elif task_type == "online_model":
            workers = min(base_workers, 10)
        elif task_type == "image_processing":
            workers = min(base_workers // 2, 4)
        else:
            workers = base_workers

        # Halving a 1-thread budget yields 0, which is not a usable pool size.
        return max(1, workers)


class CacheManager:
    """Two-level cache (in-memory dict + JSON files on disk) for per-file results.

    Disk entries live at ``<cache_dir>/<cache_key>.json``; keys are MD5 hex
    digests produced by ``get_cache_key``.
    """

    def __init__(self, cache_dir: Optional[str] = None):
        if cache_dir is None:
            cache_dir = os.path.join(os.path.expanduser('~'), '.fileneatai', 'cache')

        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        # In-memory layer. NOTE(review): unbounded; grows with every distinct key.
        self.memory_cache = {}
        self.cache_stats = {
            'hits': 0,
            'misses': 0,
            'total_requests': 0
        }

        # Remove corrupt cache files at startup; failure here is never fatal.
        try:
            self.validate_and_clean_cache()
        except Exception as e:
            logger.warning(f"初始化缓存验证失败: {str(e)}")

    def get_cache_key(self, file_path: str, content_hash: Optional[str] = None) -> str:
        """Return an MD5 hex key for ``file_path``.

        With ``content_hash`` the key covers path + hash; otherwise path +
        modification time, so an edited file gets a new key. If the mtime cannot
        be read, the path alone is used.
        """
        import hashlib

        if content_hash:
            key_data = f"{file_path}_{content_hash}"
        else:
            try:
                key_data = f"{file_path}_{os.path.getmtime(file_path)}"
            except (OSError, ValueError):
                # Missing/inaccessible file, or an invalid path (e.g. embedded NUL).
                key_data = file_path

        return hashlib.md5(key_data.encode()).hexdigest()

    def get_cached_result(self, cache_key: str):
        """Return the cached value for ``cache_key`` or None on a miss.

        Disk hits are promoted into the memory cache. A disk entry that cannot
        be decoded is deleted so it will be regenerated.
        """
        import json

        self.cache_stats['total_requests'] += 1

        if cache_key in self.memory_cache:
            self.cache_stats['hits'] += 1
            return self.memory_cache[cache_key]

        cache_file = os.path.join(self.cache_dir, f"{cache_key}.json")
        if os.path.exists(cache_file):
            try:
                with open(cache_file, 'r', encoding='utf-8') as f:
                    result = json.load(f)
            except ValueError as e:
                # json.JSONDecodeError and UnicodeDecodeError are both ValueError
                # subclasses; files are written with ensure_ascii=False, so a
                # truncated file can fail in the UTF-8 decoder before JSON parsing.
                logger.warning(f"缓存文件格式损坏，已自动删除: {os.path.basename(cache_file)} (错误: {str(e)})")
                try:
                    os.remove(cache_file)
                    logger.info(f"已删除损坏的缓存文件，系统将重新生成")
                except Exception as remove_error:
                    logger.error(f"删除损坏缓存文件失败: {str(remove_error)}")
            except Exception as e:
                # Other failures (e.g. permissions): keep the file, report a miss.
                logger.warning(f"读取缓存文件失败: {str(e)} - 文件: {os.path.basename(cache_file)}")
            else:
                self.memory_cache[cache_key] = result
                self.cache_stats['hits'] += 1
                return result

        self.cache_stats['misses'] += 1
        return None

    def save_result(self, cache_key: str, result: dict):
        """Store ``result`` in memory and atomically on disk; errors are logged, not raised.

        The JSON is written to a uniquely named temp file in ``cache_dir`` and
        then moved over the target with ``os.replace``, so readers never see a
        partially written file and concurrent writers do not share a temp path.
        """
        import json
        import tempfile

        try:
            self.memory_cache[cache_key] = result

            cache_file = os.path.join(self.cache_dir, f"{cache_key}.json")
            fd, temp_file = tempfile.mkstemp(dir=self.cache_dir, prefix=f"{cache_key}.", suffix='.tmp')

            try:
                with os.fdopen(fd, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                    f.flush()
                    if hasattr(os, 'fsync'):
                        os.fsync(f.fileno())

                # os.replace overwrites an existing target atomically on all platforms.
                os.replace(temp_file, cache_file)

            except Exception:
                try:
                    os.remove(temp_file)
                except OSError:
                    pass  # best effort; the original error is more important
                raise

        except Exception as e:
            logger.error(f"保存缓存失败: {str(e)}")

    def get_cache_stats(self) -> dict:
        """Return hit/miss counters plus hit_rate and the memory-cache size."""
        hit_rate = self.cache_stats['hits'] / max(self.cache_stats['total_requests'], 1)
        return {
            **self.cache_stats,
            'hit_rate': hit_rate,
            'memory_cache_size': len(self.memory_cache)
        }

    def clear_cache(self):
        """Empty the memory cache and delete and recreate the cache directory."""
        import shutil

        self.memory_cache.clear()
        try:
            shutil.rmtree(self.cache_dir)
            os.makedirs(self.cache_dir, exist_ok=True)
            logger.info("缓存已清理")
        except Exception as e:
            logger.error(f"清理缓存失败: {str(e)}")

    def validate_and_clean_cache(self):
        """Delete ``*.json`` files in ``cache_dir`` that cannot be decoded as UTF-8 JSON."""
        import json

        if not os.path.exists(self.cache_dir):
            return

        try:
            corrupted_files = []
            total_files = 0

            for filename in os.listdir(self.cache_dir):
                if not filename.endswith('.json'):
                    continue

                total_files += 1
                cache_file = os.path.join(self.cache_dir, filename)

                try:
                    with open(cache_file, 'r', encoding='utf-8') as f:
                        json.load(f)
                except ValueError:
                    # Covers both JSONDecodeError and UnicodeDecodeError.
                    corrupted_files.append(filename)
                    try:
                        os.remove(cache_file)
                    except OSError:
                        pass  # best effort; get_cached_result also removes it on access
                except Exception:
                    # Unreadable for other reasons (e.g. permissions): leave it alone.
                    pass

            if corrupted_files:
                logger.info(f"缓存验证完成: 发现并清理了 {len(corrupted_files)} 个损坏的缓存文件（共 {total_files} 个）")
            else:
                logger.debug(f"缓存验证完成: 所有 {total_files} 个缓存文件完好")

        except Exception as e:
            logger.error(f"缓存验证失败: {str(e)}")


# Module-level shared instances, created once when this module is imported.
thread_manager: SmartThreadManager = SmartThreadManager()
# Import-time side effect: CacheManager() creates the default cache directory
# and validates the cache files already in it.
cache_manager: CacheManager = CacheManager()
