# experiment 模块说明 ## 模块作用 在固定用户偏好(liked/disliked 特征)下,模拟多轮交互、收集反馈、评估基线、对比学习、偏好加权三种方法的准确率、精确率、召回率、F1,并支持多样性采样与绘图。 ## 文件与接口概览 - **simulator.py** — UserSimulator:`__init__(liked_features, disliked_features)`;`label_pattern(pattern) -> 0|1`;`build_ground_truth(patterns) -> Dict[int,int]` - **metrics.py** — `accuracy(y_true, y_pred)`、`precision`、`recall`、`f1`;`compute_threshold(scores) -> float` - **sampler.py** — DiversitySampler.`sample_diverse_patterns(patterns, scores, top_k, strategy) -> List[Dict]` - **evaluator.py** — InteractionEvaluator:`__init__(manager, simulator, top_k, rounds, sampling_strategy)`;`run() -> Dict`(含 accuracy/precision/recall/f1 列表及 pre_interaction) - **runner.py** — `run_experiment(baseline_config, contrastive_config, preference_weighted_config=None, liked_features, disliked_features, top_k, rounds, output_dir, sampling_strategy) -> Dict` - **plotter.py** — `plot_learning_curve(...)`、`plot_all_metrics_func(...)`、`plot_from_json(...)`:传入各方法 metrics 路径或列表,输出图片。 ## 文件与类/函数 | 文件 | 类/函数 | 说明 | |------|----------|------| | simulator.py | UserSimulator | 根据 liked/disliked 为模式打标签、构建 ground truth | | metrics.py | accuracy, precision, recall, f1, compute_threshold | 评估指标与阈值 | | sampler.py | DiversitySampler | 多种策略从排序结果中采样 top-k 模式 | | evaluator.py | InteractionEvaluator | 单种方法的评估循环(挖掘→打分→采样→反馈→更新→指标) | | runner.py | run_experiment | 多方法对比入口(baseline / contrastive / preference_weighted) | | plotter.py | plot_learning_curve, plot_all_metrics_func, plot_from_json | 学习曲线与多指标图 | --- ## 接口说明 ### 1. UserSimulator(simulator.py) #### `__init__(liked_features: Set[str], disliked_features: Set[str] | None = None)` - **传入参数**:喜欢的特征集合、不喜欢的特征集合。 - **传出参数**:无。 #### `label_pattern(pattern: Dict[str, Any]) -> int` - **传入参数**:含 `pattern` 键的模式字典(值为特征列表或逗号分隔串)。 - **传出参数**:1 表示喜欢,0 表示不喜欢/不感兴趣。 #### `build_ground_truth(patterns: List[Dict[str, Any]]) -> Dict[int, int]` - **传入参数**:模式列表。 - **传出参数**:{ 模式索引: 0|1 } 的 ground truth 字典。 --- ### 2. metrics.py #### `accuracy(y_true: Dict[int, int], y_pred: Dict[int, int]) -> float` - **传入参数**:真实标签字典、预测标签字典(索引 → 0/1)。 - **传出参数**:准确率。 #### `precision` / `recall` / `f1(y_true, y_pred) -> float` - **传入参数**:同上。 - **传出参数**:精确率 / 召回率 / F1。 #### `compute_threshold(scores: List[float]) -> float` - **传入参数**:分数列表。 - **传出参数**:用于二值化的阈值(当前实现为固定 0.7)。 --- ### 3. DiversitySampler(sampler.py) #### `sample_diverse_patterns(patterns, scores, top_k=5, strategy="mixed") -> List[Dict]` - **传入参数**:模式列表、分数列表、选取数量、策略("top" | "mixed" | "stratified" | "uncertainty")。 - **传出参数**:选中的模式子列表。 --- ### 4. InteractionEvaluator(evaluator.py) #### `__init__(manager, simulator, top_k=5, rounds=10, sampling_strategy="mixed")` - **传入参数**:PipelineManager、UserSimulator、每轮推荐数、轮数、采样策略。 - **传出参数**:无。 #### `run() -> Dict[str, Any]` - **传入参数**:无(内部用 manager 挖掘、打分、模拟反馈、更新、计算指标)。 - **传出参数**:`accuracy`、`precision`、`recall`、`f1`(每轮列表)及 `pre_interaction`(交互前单次指标)。 --- ### 5. run_experiment(runner.py) #### `run_experiment(baseline_config, contrastive_config, preference_weighted_config=None, liked_features=None, disliked_features=None, top_k=5, rounds=10, output_dir="results", sampling_strategy="mixed") -> Dict` - **传入参数**:基线/对比/偏好加权配置路径,喜欢的特征集合、不喜欢的特征集合,top_k、轮数、输出目录、采样策略。 - **传出参数**:`{ "baseline": {...}, "contrastive": {...}, "preference_weighted": {...} }`(若未跑第三种则无 preference_weighted)。 --- ### 6. plotter.py #### `plot_learning_curve(baseline, ours, output_path=None, title=..., preference_weighted=None)` - **传入参数**:基线准确率列表、对比方法准确率列表、输出路径、标题、可选第三方法列表。 - **传出参数**:无(保存图或显示)。 #### `plot_all_metrics_func(baseline_metrics, contrastive_metrics, output_path=None, preference_weighted_metrics=None)` - **传入参数**:各方法的 accuracy/precision/recall/f1 列表字典,输出路径,可选第三方法字典。 - **传出参数**:无。 #### `plot_from_json(baseline_path, contrastive_path, output_path=None, plot_all_metrics=False, preference_weighted_path=None)` - **传入参数**:两个(或三个)metrics JSON 路径,输出路径,是否画四指标图,可选第三方法路径。 - **传出参数**:无。