这份文档对应当前目录下的 toy 项目代码,目标是把代码从“能跑”解释成“你能自己改”。
说明:
README.md 和 notes.md。这个 toy 的执行路径是:
run.py
ExperimentConfigexperiment.py 里的 run_experimentexperiment.py
trainer.py 评估和画图trainer.py
rules.py 动态构造的 soft targetrules.py
data.py
所以你可以把整个项目压成一句话:
先造二维数据,再把规则转成 teacher 分布,然后让 student 同时学标签和 teacher。
这个文件只做一件事:集中定义实验配置。
1: from dataclasses import dataclass
dataclass,后面用它把配置类写成简洁的“字段集合”。2: from pathlib import Path
Path 处理结果目录路径。3: from typing import Optional
results_dir 初始可以是 None。6: @dataclass
7: class ExperimentConfig:
8: seed: int = 7
9: num_labeled: int = 64
10: num_unlabeled: int = 512
11: num_val: int = 256
12: num_test: int = 512
13: hidden_dim: int = 32
14: batch_size_labeled: int = 32
15: batch_size_unlabeled: int = 64
16: epochs: int = 140
17: learning_rate: float = 1e-2
18: weight_decay: float = 1e-4
19: rule_temperature: float = 2.0
20: rule_strength: float = 1.25
21: rule_set_name: str = "single_good"
22: max_distill_weight: float = 0.65
23: ramp_up_epochs: int = 40
24: mesh_step: float = 0.04
25: device: str = "cpu"
26: experiment_name: str = "default"
27: results_root: Path = Path(__file__).resolve().parent / "results"
results/。28: results_dir: Optional[Path] = None
30: def ensure_results_dir(self) -> Path:
31-32
results_dir,就用 results_root / experiment_name 作为默认目录。33
parents=True 允许自动创建多级目录。34
36: def to_dict(self) -> dict:
metrics.json。37-58
results_dir 这里转成字符串,因为 JSON 不能直接序列化 Path。这个文件的核心思想很简单:
所有超参数统一放在一个对象里,训练、评估、画图都只依赖这个对象。
这个文件负责生成 toy 数据。
1: 导入 dataclass3: 导入 torch4: 导入 TensorDataset
7: @dataclass8: class DatasetBundle:
9: labeled: TensorDataset
10: unlabeled: TensorDataset
11: val: TensorDataset
12: test: TensorDataset
15: def set_seed(seed: int) -> None:
16
torch.manual_seed(seed)。19: def decision_score(x: torch.Tensor) -> torch.Tensor:
20: x1 = x[:, 0]
21: x2 = x[:, 1]
22
这一行的意义是:
25: def make_labels(x: torch.Tensor) -> torch.Tensor:
26
decision_score(x) > 0.0 为真记为 1,否则为 0。.long() 把布尔值转成整型类别标签。29: _sample_candidate_pool(...)
30
[-2, 2] x [-2, 2] 内均匀采样二维点。31
32
35: def make_balanced_dataset(...)
36: half = num_points // 2
37-38
40-43
45
46
48-49
51-56
58-63
65
66
67
68
71: def create_datasets(config) -> DatasetBundle:
72
torch.Generator。74-77
79-84
DatasetBundle 返回。这一整个文件要表达的是:
先定义一个真实但不太复杂的二维分类边界,再人为制造“少标注 + 多无标签”的环境。
这个文件非常简单,只定义模型。
1: import torch.nn as nn
4: class TinyMLP(nn.Module):
5: def __init__(self, hidden_dim: int = 32) -> None:
6: super().__init__()
7-13
nn.Sequential 堆出一个两层隐藏层的网络:8: 输入 2 维 -> 隐藏层9: Tanh10: 隐藏层 -> 隐藏层11: Tanh12: 隐藏层 -> 2 类输出 logits15: def forward(self, x):
16
self.net。这里的设计意图是:
这个文件是项目的关键,因为它把“规则”变成了“teacher 概率分布”。
1: 导入 asdict 和 dataclass
RuleSpec 需要 dataclass,后面写 JSON 要用 asdict3: 导入 torch4: 导入 torch.nn.functional as F7: @dataclass(frozen=True)
frozen=True 表示规则定义后不可修改。8: class RuleSpec:
9: name
10: description
11: coefficients
(a, b)。12: bias
13: positive_class
14: weight
15: temperature_scale
18: RULE_SETS = {
single_good19-27
x1 > x2 时偏向 class 1。single_bad28-36
multi_good37-61
diag_positive:对角线规则。x1_positive:x1 > 0.15 倾向 class 1。x2_small:x2 < 0.10 倾向 class 1。multi_mixed62-86
multi_bad87-111
115-116
119-123
125-126
129: def rule_margin(x, rule):
130
a*x1 + b*x2 + c。这个 margin 的意义是:
133: soft_rule_probability_for_rule(...)
134
135
(0, 1) 概率。136-138
1 - sigmoid。141: rule_distribution(...)
[p(class0), p(class1)]。142
143
torch.stack 拼成二分类分布。146-147
150: aggregate_rule_distribution(...)
151-152
154
log_rule,形状是 (batch_size, 2)。155-157
158
这里的设计相当于:
多条规则在 log-probability 空间做加权融合。
161-162
argmax 得到多规则最终预测。165-171
build_teacher_probs(...)。172
173
175-176
177
teacher_logits = log_student + rule_strength * log_rule178
181-183
diag_positive->c1。这个文件的核心思想可以压缩成一句话:
规则先变成概率分布,多条规则先聚合,再和当前 student 概率融合,形成 teacher。
这是整个项目最重要的文件,真正实现训练、评估和画图。
1-5
7-10
matplotlib、torch、损失函数接口和 DataLoader。12-18
20
matplotlib 使用 Agg 后端,这样即使没有 GUI 也能存图。21
pyplot。ManualAdamW这个类是因为当前环境里 torch.optim 受 sympy 问题影响,所以这里手写了一个简版 AdamW。
24: 定义 ManualAdamW25-30
26
31
32-38
exp_avgexp_avg_sq40-43
45
step() 真正执行参数更新。46-48
50-70
73-74
_to_device:把一个 batch 中的每个张量都搬到目标设备。77-79
distill_weight_at:蒸馏权重调度器。ramp_up_epochs 内线性增长,之后保持 max_distill_weight。82: evaluate(model, dataset, config, rule_specs)
83
84
TensorDataset 里取出整份 x, y。85-86
87-90
92
93
94-101
103
104
105
107-118
120-126
129-134
_make_optimizer:统一返回 ManualAdamW。137: train_baseline(...)
138
139
140
142-144
146-167
147 进入 train 模式148 初始化 epoch loss149-156 遍历 labeled batch
158 算平均训练损失159 在验证集上评估160-161 记录历史163-165 如果验证准确率更好,就保存当前模型167-168
171: train_logic_guided(...)
172-175
177-179
181-239
182
183
pi_t。184-185
itertools.cycle 把 loader 变成可循环迭代器。186
193-195
197
199
200
202
203
204-211
torch.no_grad() 下构造 teacher:
206 用当前 student logits207 用当前 batch 样本208 用规则集209-210 用规则强度和温度213-217
219
(1 - pi_t) * supervised_loss + pi_t * distill_loss220-221
223-227
229
230-233
235-237
239-240
243-245
metrics.json。248-266
269: _plot_rule_boundaries(...)
270
271-283
a, b, cb != 0,画斜线 a*x + b*y + c = 0b == 0 但 a != 0,画竖线286: plot_decision_boundaries(...)
287-288
290-295
297-298
300-312
contourf 画概率热图contour 画 0.5 决策边界314-316
这个文件是整个项目的核心,因为它把“论文思想”真正变成了:
这个文件是实验编排层。
1-8
trainer.py 导入训练、评估、画图、存指标函数。9
data.py 导入数据生成和设种子。10
rules.py 导入规则查询和序列化工具。13: run_experiment(config, save_artifacts=True, save_plots=True)
14
15
16
17
19
20
22-23
24-33
35-50
metrics.json52
run.py 或 sweep.py 使用。这个文件的作用是:
把“配置 -> 数据 -> 训练 -> 评估 -> 落盘”串成一条统一流程。
这个文件是单次实验入口。
1-2
4
5
6
9: def parse_args():
10
ArgumentParser。11
--seed12
--epochs13
--num-labeled14
--num-unlabeled15
--rule-strength16
--max-distill-weight17
--rule-setchoices=available_rule_sets() 限制只能用已知规则集。18
--experiment-name19
--skip-plots20
23: def build_config(args) -> ExperimentConfig:
24-34
ExperimentConfig。33
35
36
39: def main():
40
41
42
44
45
46-48
51-52
main()。这个文件负责批量参数扫描。
1-5
7
8
9
12: parse_csv_list(raw_value, cast_fn)
"7,13" 这种字符串拆成列表。13
cast_fn 转型。16-32
nameepochsseedsnum-labeled-valuesrule-strengthsdistill-weightsrule-setsnum-unlabeledsave-plots35-40
43: save_summary(rows, output_dir)
44-45
47-48
50-54
57: def main():
58
59-63
65-66
68
69-71
itertools.product 枚举所有超参数组合。73
74-77
78-88
ExperimentConfig。90
91-107
108-112
114
delta_accuracy 从大到小排序。115
117-120
best_run.json。121
124-125
这个文件的意义是:
把单次 toy 实验变成了可做 ablation、可比较不同规则集和超参数的实验平台。
data.py 定义了一个“少标注 + 多无标签”的 toy 场景。rules.py 的核心不是规则本身,而是“把规则变成 teacher 概率分布”。trainer.py 的 logic-guided 版本本质是监督损失和 KL 蒸馏损失的组合。experiment.py 负责把数据、规则、训练和画图串起来。sweep.py 让这个 toy 不再只是 demo,而是可以做系统实验。如果你下一步要改代码,最推荐先改的地方是:
RULE_SETSbuild_teacher_probsdistill_weight_atsweep.py 里的扫描维度