"""Co-location pattern mining core module."""

import json
import logging
import math
import random
import time
from collections import defaultdict
from itertools import combinations
from typing import Dict, List, Set, Tuple, Any, Optional

import numpy as np

logger = logging.getLogger(__name__)

# Try to import torch for GPU acceleration
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False


class CoLocationMiner:
    """Co-location pattern miner based on the participation index.

    Every instance is identified by the string ``f"{type}{id}"``. A pattern is
    stored as a comma-separated, alphabetically sorted feature string such as
    ``"A,B,C"``; each of its table instances is a list of instance identifiers
    whose i-th entry belongs to the i-th feature of that string.
    """

    # Maximum number of pairwise distances materialized at once while building
    # 2-order table instances; rows of the first feature are processed in
    # chunks so memory does not grow with n1 * n2.
    _PAIR_CHUNK_ELEMENTS = 4_000_000
    # Pending neighbor checks accumulated before one vectorized evaluation
    # during k-order candidate generation.
    _NEIGHBOR_BATCH_SIZE = 100_000
    # Only batches larger than this go through PyTorch when use_gpu is enabled.
    _GPU_MIN_BATCH = 1000

    def __init__(self, data_path: str, distance_threshold: float = 16.0, use_gpu: bool = False):
        """Load the data set and prepare the lookup structures.

        Args:
            data_path: Path to a JSON file holding a list of instance objects,
                each with ``type``, ``id``, ``x`` and ``y`` keys.
            distance_threshold: Maximum Euclidean distance (inclusive) at which
                two instances count as neighbors.
            use_gpu: Use PyTorch for the vectorized distance calculations when
                it is installed. Without CUDA the PyTorch path runs on the CPU.

        Raises:
            ValueError: If the file does not contain a list of valid instances.
        """
        self.distance_threshold = distance_threshold
        self.data = self._load_data(data_path)
        self.feature_counts = self._count_features()
        # Map "TypeId" strings (the form used in table instances) to instance dicts.
        self._instance_map = self._build_instance_map()

        # GPU acceleration setup
        self.use_gpu = use_gpu and TORCH_AVAILABLE
        if self.use_gpu and torch.cuda.is_available():
            self.device = "cuda"
            logger.info(f"GPU acceleration enabled for pattern mining (device: {torch.cuda.get_device_name(0)})")
        else:
            self.device = "cpu"
            if use_gpu and not TORCH_AVAILABLE:
                logger.warning("PyTorch not available, GPU acceleration disabled")
            elif use_gpu:
                logger.warning("CUDA not available, using CPU")

        # Coordinate array used by the batched distance computations.
        self._coords = None
        self._build_coords_array()

        logger.info(f"Loaded {len(self.data)} instances with {len(self.feature_counts)} features")

    def _load_data(self, data_path: str) -> List[Dict[str, Any]]:
        """Load and validate instances from a JSON file.

        Args:
            data_path: Path to the data file.

        Returns:
            List of instance dicts.

        Raises:
            ValueError: If the top-level value is not a list, or an entry is
                not an object with ``type``, ``id``, ``x`` and ``y`` keys
                (``id`` is required to build instance identifiers).
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if not isinstance(data, list):
            raise ValueError(
                f"Invalid data format: expected a list of instances, got {type(data).__name__}"
            )

        required_keys = ('type', 'id', 'x', 'y')
        for item in data:
            if not isinstance(item, dict) or any(key not in item for key in required_keys):
                raise ValueError(f"Invalid data format: {item}")

        return data

    def _count_features(self) -> Dict[str, int]:
        """Count instances per feature type.

        Returns:
            Dictionary mapping feature type to its number of instances.
        """
        counts = defaultdict(int)
        for item in self.data:
            counts[item['type']] += 1
        return dict(counts)

    @staticmethod
    def _instance_key(item: Dict[str, Any]) -> str:
        """Return the identifier used for *item* inside table instances."""
        return f"{item['type']}{item['id']}"

    def _build_instance_map(self) -> Dict[str, Dict[str, Any]]:
        """Map every instance identifier to its instance dict.

        Identifiers are plain ``type + id`` concatenations, so e.g. ``("A", 11)``
        and ``("A1", 1)`` both become ``"A11"``. Such collisions (and repeated
        type/id pairs) are reported once; the later instance wins the lookup.
        """
        instance_map: Dict[str, Dict[str, Any]] = {}
        duplicates = 0
        for item in self.data:
            key = self._instance_key(item)
            if key in instance_map:
                duplicates += 1
            instance_map[key] = item
        if duplicates:
            logger.warning(
                f"{duplicates} instances share a 'type'+'id' identifier with an earlier "
                f"instance; identifier lookups resolve to the last one"
            )
        return instance_map

    def _build_coords_array(self):
        """Build the ``(n, 2)`` float64 coordinate array used by the batched paths.

        float64 keeps the vectorized neighbor tests at the same precision as
        :meth:`_is_neighbor`, which works on the raw Python values.
        """
        self._coords = np.array(
            [(item['x'], item['y']) for item in self.data], dtype=np.float64
        ).reshape(-1, 2)

        # Threshold tensor reused by the PyTorch paths.
        if self.use_gpu and TORCH_AVAILABLE:
            self._distance_threshold_gpu = torch.tensor(
                self.distance_threshold, dtype=torch.float64, device=self.device
            )

    def _is_neighbor(self, instance1: Dict, instance2: Dict) -> bool:
        """Return True if two instances lie within ``distance_threshold`` of each other.

        Args:
            instance1: First instance dict (reads ``x``/``y``).
            instance2: Second instance dict (reads ``x``/``y``).
        """
        distance = math.sqrt(
            (instance1['x'] - instance2['x']) ** 2 +
            (instance1['y'] - instance2['y']) ** 2
        )
        return distance <= self.distance_threshold

    def _group_indices_by_type(self) -> Dict[str, List[int]]:
        """Return the row indices of ``self.data`` grouped by feature type."""
        groups: Dict[str, List[int]] = defaultdict(list)
        for index, item in enumerate(self.data):
            groups[item['type']].append(index)
        return dict(groups)

    def _neighbor_pairs(self, coords1: np.ndarray, coords2: np.ndarray) -> np.ndarray:
        """Return every ``(row1, row2)`` pair of points within the threshold.

        Args:
            coords1: ``(n1, 2)`` coordinate array.
            coords2: ``(n2, 2)`` coordinate array.

        Returns:
            ``(m, 2)`` integer array of positions into ``coords1``/``coords2``
            in row-major order (sorted by ``row1``, then ``row2``).
        """
        n1, n2 = len(coords1), len(coords2)
        if n1 == 0 or n2 == 0:
            return np.empty((0, 2), dtype=np.intp)

        use_torch = self.use_gpu and TORCH_AVAILABLE
        if use_torch:
            coords2_t = torch.from_numpy(coords2).to(self.device)

        # Chunk the rows of coords1 so each broadcast (rows, n2, 2) array stays bounded.
        rows_per_chunk = max(1, self._PAIR_CHUNK_ELEMENTS // n2)
        chunks = []
        for start in range(0, n1, rows_per_chunk):
            block = coords1[start:start + rows_per_chunk]
            if use_torch:
                block_t = torch.from_numpy(block).to(self.device)
                diff = block_t.unsqueeze(1) - coords2_t.unsqueeze(0)
                distances = torch.sqrt(torch.sum(diff ** 2, dim=2))
                hits = torch.nonzero(
                    distances <= self._distance_threshold_gpu, as_tuple=False
                ).cpu().numpy()
            else:
                diff = block[:, np.newaxis, :] - coords2[np.newaxis, :, :]
                distances = np.sqrt(np.sum(diff ** 2, axis=2))
                hits = np.argwhere(distances <= self.distance_threshold)
            if hits.size:
                hits[:, 0] += start  # chunk-local row -> row in coords1
                chunks.append(hits)

        if not chunks:
            return np.empty((0, 2), dtype=np.intp)
        return np.concatenate(chunks)

    def _pair_table_instances(self, indices1: List[int], indices2: List[int]) -> List[List[str]]:
        """Build the table instances of one 2-feature pattern.

        Args:
            indices1: Row indices (into ``self.data``) of the alphabetically first feature.
            indices2: Row indices of the second feature.

        Returns:
            ``[[first_id, second_id], ...]`` for every neighboring pair.
        """
        if not indices1 or not indices2:
            return []
        pairs = self._neighbor_pairs(self._coords[indices1], self._coords[indices2])
        data = self.data
        return [
            [self._instance_key(data[indices1[pos1]]), self._instance_key(data[indices2[pos2]])]
            for pos1, pos2 in pairs
        ]

    def _create_table_instances_2(self) -> Dict[str, List[List[str]]]:
        """Create the table instances of every 2-feature pattern.

        Distances are computed per feature-type pair in vectorized, row-chunked
        batches. Each unordered feature pair is visited once, so "a,b" and
        "b,a" are never produced separately.

        Returns:
            Dictionary mapping ``"A,B"`` (alphabetical) to its table instances;
            pairs without any neighboring instances are omitted.
        """
        start_time = time.time()
        logger.info("Starting 2-order table instance creation with batch optimization...")

        type_indices = self._group_indices_by_type()
        logger.info(f"Grouped instances by type: {len(type_indices)} types")

        # Sorted types make every pair (and therefore every pattern key) alphabetical.
        type_pairs = list(combinations(sorted(type_indices), 2))
        logger.info(f"Processing {len(type_pairs)} type pairs...")

        table_instances: Dict[str, List[List[str]]] = {}
        for processed_pairs, (type1, type2) in enumerate(type_pairs, start=1):
            instances = self._pair_table_instances(type_indices[type1], type_indices[type2])
            if instances:
                table_instances[f"{type1},{type2}"] = instances

            if processed_pairs % 10 == 0:
                elapsed = time.time() - start_time
                logger.info(f"Processed {processed_pairs}/{len(type_pairs)} type pairs ({elapsed:.1f}s elapsed)")

        elapsed = time.time() - start_time
        total_instances = sum(len(v) for v in table_instances.values())
        logger.info(f"Created {len(table_instances)} candidate 2-order patterns with {total_instances} total instances in {elapsed:.2f}s")

        return table_instances

    def _calculate_participation_index(self, pattern: Set[str],
                                       table_instances: Dict[str, List[List[str]]]) -> float:
        """Calculate the participation index of a pattern.

        The participation ratio of a feature is the number of distinct
        instances of that feature appearing in the table instances divided by
        the feature's total instance count; the index is the minimum ratio.

        Instances are attributed to features by column position (the i-th
        entry of a row belongs to the i-th feature of its pattern string), so
        feature names that are prefixes of each other ("A" vs "AB") are not
        confused.

        Args:
            pattern: Set of feature types in the pattern.
            table_instances: Mapping of pattern string to its table instances.

        Returns:
            Participation index in [0, 1]; 0.0 when there are no table instances.
        """
        if not table_instances:
            return 0.0

        participating: Dict[str, Set[str]] = {feature: set() for feature in pattern}
        for pattern_str, instances in table_instances.items():
            columns = [
                (pos, feature)
                for pos, feature in enumerate(pattern_str.split(','))
                if feature in participating
            ]
            if not columns:
                continue
            for instance in instances:
                for pos, feature in columns:
                    if pos < len(instance):
                        participating[feature].add(instance[pos])

        rates = []
        for feature in pattern:
            total_instances = self.feature_counts.get(feature, 1)
            rates.append(len(participating[feature]) / total_instances if total_instances > 0 else 0.0)

        return min(rates) if rates else 0.0

    def _are_instances_neighbors(self, inst1_str: str, inst2_str: str) -> bool:
        """Check whether two instance identifiers refer to neighboring instances.

        Args:
            inst1_str: First instance identifier (e.g. "Restaurant1").
            inst2_str: Second instance identifier (e.g. "Cafe2").

        Returns:
            True if both identifiers are known and within the threshold;
            False if either is unknown.
        """
        inst1 = self._instance_map.get(inst1_str)
        inst2 = self._instance_map.get(inst2_str)

        if inst1 is None or inst2 is None:
            return False

        return self._is_neighbor(inst1, inst2)

    @staticmethod
    def _group_by_prefix(instances: List[List[str]]) -> Dict[Tuple[str, ...], List[List[str]]]:
        """Group table instances by all entries except the last one."""
        groups: Dict[Tuple[str, ...], List[List[str]]] = defaultdict(list)
        for instance in instances:
            if len(instance) > 1:
                groups[tuple(instance[:-1])].append(instance)
        return groups

    def _neighbor_flags(self, first: List[Optional[Dict[str, Any]]],
                        second: List[Optional[Dict[str, Any]]],
                        allow_gpu: bool) -> np.ndarray:
        """Element-wise neighbor test for two equally long lists of instances.

        ``None`` entries (identifiers missing from the instance map) never
        count as neighbors, matching :meth:`_are_instances_neighbors`.
        """
        found = np.array([a is not None and b is not None for a, b in zip(first, second)], dtype=bool)
        coords1 = np.array(
            [(a['x'], a['y']) if a is not None else (0.0, 0.0) for a in first], dtype=np.float64
        ).reshape(-1, 2)
        coords2 = np.array(
            [(b['x'], b['y']) if b is not None else (0.0, 0.0) for b in second], dtype=np.float64
        ).reshape(-1, 2)

        if allow_gpu and self.use_gpu and TORCH_AVAILABLE:
            coords1_t = torch.from_numpy(coords1).to(self.device)
            coords2_t = torch.from_numpy(coords2).to(self.device)
            distances = torch.sqrt(torch.sum((coords1_t - coords2_t) ** 2, dim=1))
            within = (distances <= self._distance_threshold_gpu).cpu().numpy()
        else:
            distances = np.sqrt(np.sum((coords1 - coords2) ** 2, axis=1))
            within = distances <= self.distance_threshold

        return within & found

    def _apply_neighbor_batch(self, batch: List[Tuple[str, List[str], List[str]]],
                              k_patterns: Dict[str, List[List[str]]]) -> None:
        """Evaluate a batch of join candidates and append the valid k-order instances.

        Each candidate is ``(new_pattern, inst1, inst2)`` where both rows share
        the same prefix; it is kept when their last instances are neighbors.
        """
        if not batch:
            return
        first = [self._instance_map.get(inst1[-1]) for _, inst1, _ in batch]
        second = [self._instance_map.get(inst2[-1]) for _, _, inst2 in batch]
        flags = self._neighbor_flags(first, second, allow_gpu=len(batch) > self._GPU_MIN_BATCH)
        for (pattern, inst1, inst2), is_neighbor in zip(batch, flags):
            if is_neighbor:
                k_patterns[pattern].append(inst1 + [inst2[-1]])

    def _generate_k_order_patterns(self, k_minus_1_patterns: Dict[str, List[List[str]]],
                                   max_instances_per_pair: Optional[int] = None) -> Dict[str, List[List[str]]]:
        """Generate k-order table instances by joining (k-1)-order ones.

        Two (k-1)-patterns sharing their first k-2 features are joined. Two of
        their rows sharing their first k-2 instances form a new row when their
        last instances are neighbors. Neighbor checks are buffered and evaluated
        in vectorized batches.

        Args:
            k_minus_1_patterns: (k-1)-order table instances keyed by sorted pattern string.
            max_instances_per_pair: Optional cap enabling approximate mining.
                When ``len(rows1) * len(rows2)`` of a joined pair exceeds it,
                each side is randomly (seeded, reproducible) down-sampled to
                ``isqrt(cap)`` rows. Sampling discards joins and therefore
                underestimates participation; ``None`` (default) keeps results
                exact.

        Returns:
            k-order table instances keyed by sorted pattern string.

        Raises:
            ValueError: If ``max_instances_per_pair`` is given and < 1.
        """
        if max_instances_per_pair is not None and max_instances_per_pair < 1:
            raise ValueError("max_instances_per_pair must be a positive integer or None")

        start_time = time.time()
        logger.info(f"Generating k-order patterns from {len(k_minus_1_patterns)} (k-1)-order patterns...")

        # Group patterns by all features except the last; only these can be joined.
        prefix_index: Dict[str, List[Tuple[str, List[str]]]] = defaultdict(list)
        for pattern_key in sorted(k_minus_1_patterns):
            features = pattern_key.split(',')
            if len(features) > 1:
                prefix_index[','.join(features[:-1])].append((pattern_key, features))

        logger.info(f"Built prefix index with {len(prefix_index)} prefixes")

        total_pairs = sum(len(group) * (len(group) - 1) // 2 for group in prefix_index.values())
        logger.info(f"Total pattern pairs to process: {total_pairs}")

        k_patterns: Dict[str, List[List[str]]] = defaultdict(list)
        pending: List[Tuple[str, List[str], List[str]]] = []
        total_checks = 0
        batch_check_time = 0.0
        processed_pairs = 0
        rng = random.Random(0)  # only used when sampling is enabled

        def flush_pending() -> None:
            """Evaluate and clear the buffered neighbor checks."""
            nonlocal total_checks, batch_check_time
            if not pending:
                return
            batch_start = time.time()
            self._apply_neighbor_batch(pending, k_patterns)
            batch_check_time += time.time() - batch_start
            total_checks += len(pending)
            pending.clear()

        for group in prefix_index.values():
            for i, (pattern1_key, pattern1_features) in enumerate(group):
                full_inst1_list = k_minus_1_patterns[pattern1_key]
                full_inst1_by_prefix = None  # built lazily and reused for every partner

                for pattern2_key, pattern2_features in group[i + 1:]:
                    if pattern1_features[-1] == pattern2_features[-1]:
                        continue

                    # Keys are sorted and share the prefix, so the new pattern stays sorted.
                    new_pattern = ','.join(pattern1_features + [pattern2_features[-1]])

                    # Per-pair locals: sampling must not leak into later partners.
                    inst1_list = full_inst1_list
                    inst2_list = k_minus_1_patterns[pattern2_key]
                    inst1_sampled = False
                    if (max_instances_per_pair is not None
                            and len(inst1_list) * len(inst2_list) > max_instances_per_pair):
                        sample_size = max(1, math.isqrt(max_instances_per_pair))
                        if len(inst1_list) > sample_size:
                            inst1_list = rng.sample(inst1_list, sample_size)
                            inst1_sampled = True
                        if len(inst2_list) > sample_size:
                            inst2_list = rng.sample(inst2_list, sample_size)
                        logger.debug(f"Sampling instances for pattern pair: {pattern1_key} x {pattern2_key}")

                    if inst1_sampled:
                        inst1_by_prefix = self._group_by_prefix(inst1_list)
                    else:
                        if full_inst1_by_prefix is None:
                            full_inst1_by_prefix = self._group_by_prefix(full_inst1_list)
                        inst1_by_prefix = full_inst1_by_prefix

                    for inst2 in inst2_list:
                        if len(inst2) < 2:
                            continue
                        for inst1 in inst1_by_prefix.get(tuple(inst2[:-1]), ()):
                            pending.append((new_pattern, inst1, inst2))
                            if len(pending) >= self._NEIGHBOR_BATCH_SIZE:
                                flush_pending()

                    processed_pairs += 1
                    if processed_pairs % 50 == 0:
                        elapsed = time.time() - start_time
                        logger.info(f"Processed {processed_pairs}/{total_pairs} pattern pairs, checked {total_checks} neighbors ({elapsed:.1f}s elapsed)")

        flush_pending()

        elapsed = time.time() - start_time
        total_instances = sum(len(v) for v in k_patterns.values())
        logger.info(f"Total neighbor checks: {total_checks}, batch check time: {batch_check_time:.2f}s")
        logger.info(f"Generated {len(k_patterns)} k-order patterns with {total_instances} instances in {elapsed:.2f}s (batch check: {batch_check_time:.2f}s)")

        return dict(k_patterns)

    def mine_patterns(self, min_participation: float = 0.6,
                     max_pattern_size: int = 5,
                     priority: str = "confidence",
                     max_instances_per_pair: Optional[int] = None) -> List[Dict[str, Any]]:
        """Mine prevalent co-location patterns level by level.

        Args:
            min_participation: Minimum participation index for a pattern to be kept.
            max_pattern_size: Largest pattern size to mine (patterns need >= 2 features).
            priority: Sort key for the result: "confidence", "participation" or
                "size" (descending). Any other value keeps discovery order.
            max_instances_per_pair: Optional sampling cap forwarded to
                :meth:`_generate_k_order_patterns`; ``None`` keeps mining exact.

        Returns:
            List of pattern dicts with keys ``pattern``, ``pattern_str``, ``size``,
            ``participation_index``, ``confidence``, ``instance_count``,
            ``table_instances`` and ``table_instances_sample``.
        """
        logger.info(f"Mining patterns with min_participation={min_participation}, "
                   f"max_pattern_size={max_pattern_size}")

        all_patterns: List[Dict[str, Any]] = []
        if max_pattern_size < 2:
            # Co-location patterns need at least two features; skip the expensive 2-order pass.
            logger.info(f"Found {len(all_patterns)} valid patterns")
            return all_patterns

        current_patterns = self._create_table_instances_2()
        k = 2

        while current_patterns:
            logger.info(f"Mining {k}-order patterns, found {len(current_patterns)} candidates")

            # Keep only patterns meeting the participation threshold.
            valid_patterns = {}
            participation_stats = []
            for pattern_str, instances in current_patterns.items():
                pattern_set = set(pattern_str.split(','))
                participation = self._calculate_participation_index(pattern_set, {pattern_str: instances})
                participation_stats.append((pattern_str, participation, len(instances)))

                if participation >= min_participation:
                    valid_patterns[pattern_str] = instances

                    # Simplified confidence: the participation index itself.
                    confidence = participation

                    all_patterns.append({
                        'pattern': sorted(pattern_set),
                        'pattern_str': pattern_str,
                        'size': k,
                        'participation_index': participation,
                        'confidence': confidence,
                        'instance_count': len(instances),
                        'table_instances': instances,  # full table, used for rule generation
                        'table_instances_sample': instances[:5]  # first 5 rows, for display
                    })

            # Debug output: the highest-participation candidates of this level.
            if participation_stats:
                participation_stats.sort(key=lambda x: x[1], reverse=True)
                logger.info(f"Top 5 patterns by participation (threshold={min_participation}):")
                for pattern_str, participation, count in participation_stats[:5]:
                    status = "✓" if participation >= min_participation else "✗"
                    logger.info(f"  {status} {pattern_str}: participation={participation:.4f}, instances={count}")

            if k >= max_pattern_size:
                break
            current_patterns = self._generate_k_order_patterns(
                valid_patterns, max_instances_per_pair=max_instances_per_pair
            )
            k += 1

        sort_field = {
            "confidence": "confidence",
            "participation": "participation_index",
            "size": "size",
        }.get(priority)
        if sort_field is not None:
            all_patterns.sort(key=lambda p: p[sort_field], reverse=True)
        else:
            logger.warning(f"Unknown priority {priority!r}; patterns left in discovery order")

        logger.info(f"Found {len(all_patterns)} valid patterns")
        return all_patterns

    def _calculate_support(self, pattern_table_instances: List[List[str]]) -> float:
        """Calculate the support of a pattern.

        Support = number of table instances / number of unordered instance
        pairs in the whole data set (a simplified denominator).

        Args:
            pattern_table_instances: Table instances of the pattern.

        Returns:
            Support value; 0.0 for no table instances or fewer than two instances overall.
        """
        if not pattern_table_instances:
            return 0.0

        total_instances = len(pattern_table_instances)
        n = len(self.data)
        total_pairs = n * (n - 1) / 2

        return total_instances / total_pairs if total_pairs > 0 else 0.0

    def _calculate_rule_confidence(self, antecedent: Set[str], consequent: Set[str],
                                   pattern: Set[str], pattern_table_instances: List[List[str]]) -> float:
        """Calculate the confidence of the rule ``antecedent -> consequent``.

        Confidence = distinct antecedent-feature instances appearing in the
        pattern's table instances / summed instance counts of the antecedent
        features, capped at 1.0. For a single-feature antecedent this is the
        feature's participation ratio in the pattern.

        Table-instance columns are assumed to follow ``sorted(pattern)`` (the
        order this miner produces), and instances are attributed to features
        by column position rather than by identifier prefix.

        Args:
            antecedent: Left-hand side features.
            consequent: Right-hand side features (not used by the formula).
            pattern: Full pattern containing both sides.
            pattern_table_instances: Table instances of the full pattern.

        Returns:
            Confidence in [0, 1].
        """
        if not pattern_table_instances:
            return 0.0

        antecedent_columns = [pos for pos, feature in enumerate(sorted(pattern)) if feature in antecedent]
        participating = set()
        for instance in pattern_table_instances:
            for pos in antecedent_columns:
                if pos < len(instance):
                    participating.add((pos, instance[pos]))

        antecedent_total = sum(self.feature_counts.get(f, 0) for f in antecedent)
        if antecedent_total == 0:
            return 0.0

        return min(len(participating) / antecedent_total, 1.0)

    def _get_pattern_table_instances(self, pattern: Set[str]) -> Dict[str, List[List[str]]]:
        """Return the table instances of a 2-feature pattern.

        Only the requested feature pair is computed (not every pair in the
        data set). Patterns of any other size yield ``{}``: single features
        have no table instances and higher orders are not reconstructed here.

        Args:
            pattern: Set of feature types.

        Returns:
            ``{"A,B": table_instances}`` when the pair has neighboring
            instances, otherwise ``{}``.
        """
        if len(pattern) != 2:
            return {}

        type1, type2 = sorted(pattern)
        type_indices = self._group_indices_by_type()
        instances = self._pair_table_instances(type_indices.get(type1, []), type_indices.get(type2, []))
        return {f"{type1},{type2}": instances} if instances else {}

    def _calculate_lift(self, antecedent: Set[str], consequent: Set[str],
                       confidence: float, pattern_table_instances: List[List[str]]) -> float:
        """Calculate the lift of a rule.

        Lift = confidence / P(consequent), where P(consequent) is the summed
        instance count of the consequent features divided by the total number
        of instances.

        Args:
            antecedent: Left-hand side features (not used by the formula).
            consequent: Right-hand side features.
            confidence: Confidence of the rule.
            pattern_table_instances: Table instances of the full pattern (unused).

        Returns:
            Lift value; 0.0 when the data set or the consequent is empty.
        """
        consequent_total = sum(self.feature_counts.get(f, 0) for f in consequent)
        total_instances = len(self.data)

        if total_instances == 0 or consequent_total == 0:
            return 0.0

        consequent_prob = consequent_total / total_instances
        return confidence / consequent_prob

    def generate_rules(self, patterns: List[Dict[str, Any]],
                      min_confidence: float = 0.5,
                      min_lift: float = 1.0) -> List[Dict[str, Any]]:
        """Generate association rules from mined patterns.

        For every pattern, each non-empty proper subset of its features is used
        as antecedent with the remaining features as consequent; e.g. {A, B, C}
        yields A -> {B, C}, {A, B} -> C, and so on.

        Args:
            patterns: Patterns as returned by :meth:`mine_patterns` (must carry
                ``pattern``, ``pattern_str`` and ideally ``table_instances``).
            min_confidence: Minimum confidence for a rule to be kept.
            min_lift: Minimum lift for a rule to be kept.

        Returns:
            Rule dicts sorted by descending confidence.
        """
        logger.info(f"Generating rules from {len(patterns)} patterns with min_confidence={min_confidence}, min_lift={min_lift}")

        all_rules = []

        for pattern_data in patterns:
            pattern = set(pattern_data['pattern'])
            pattern_str = pattern_data['pattern_str']
            table_instances = pattern_data.get('table_instances', [])

            if not table_instances:
                logger.debug(f"Skipping pattern {pattern_str}: no table instances")
                continue

            pattern_support = self._calculate_support(table_instances)
            pattern_list = sorted(pattern)

            for r in range(1, len(pattern)):
                for antecedent_tuple in combinations(pattern_list, r):
                    antecedent = set(antecedent_tuple)
                    consequent = pattern - antecedent

                    if not consequent:  # the consequent must not be empty
                        continue

                    confidence = self._calculate_rule_confidence(
                        antecedent, consequent, pattern, table_instances
                    )
                    if confidence < min_confidence:
                        continue

                    lift = self._calculate_lift(antecedent, consequent, confidence, table_instances)
                    if lift < min_lift:
                        continue

                    antecedent_sorted = sorted(antecedent)
                    consequent_sorted = sorted(consequent)
                    all_rules.append({
                        'antecedent': antecedent_sorted,
                        'consequent': consequent_sorted,
                        'rule_str': f"{antecedent_sorted} → {consequent_sorted}",
                        'confidence': confidence,
                        'support': pattern_support,
                        'lift': lift,
                        'pattern': pattern_list,
                        'pattern_participation': pattern_data.get('participation_index', 0.0)
                    })

        all_rules.sort(key=lambda x: x['confidence'], reverse=True)

        logger.info(f"Generated {len(all_rules)} rules")
        return all_rules

