这份文档对应当前目录下的 knowledge_landmarks_toy 项目,目标是把“局部数据 + 全局 landmarks + 组合损失”这套思路彻底拆开。
说明:
README.md 和 notes.md。这个 toy 的核心执行路径是:
run.py
ExperimentConfigexperiment.pyexperiment.py
trainer.py
landmarks.py
data.py
所以这套代码的核心思想是:
baseline 只看局部窗口,knowledge-guided 还会被全局 landmarks 拉住,因此更可能在窗口外也表现稳定。
这个文件定义所有实验超参数。
1: 导入 dataclass2: 导入 Path3: 导入 Optional6: @dataclass7: class ExperimentConfig:
8: seed: int = 11
9: num_train_local: int = 32
10: num_val_global: int = 301
11: num_test_global: int = 601
12: hidden_dim: int = 48
13: batch_size: int = 24
14: epochs: int = 220
15: learning_rate: float = 1e-2
16: weight_decay: float = 1e-4
17: lambda_data: float = 0.7
lambda * L1 + (1 - lambda) * L2 的 lambda。18: center_pull_weight: float = 0.10
19: label_noise_std: float = 0.03
20: local_region_low: float = -0.7
21: local_region_high: float = 0.7
22: domain_low: float = -3.0
23: domain_high: float = 3.0
24: support_points_per_landmark: int = 40
25: landmark_set_name: str = "good"
26: device: str = "cpu"27: experiment_name: str = "default"28: results_root = 当前目录/results29: results_dir 初始可为空31-35
ensure_results_dir() 保证结果目录存在。results_root / experiment_name。37-60
to_dict() 把所有配置展开成普通字典。metrics.json 里会带完整配置,方便复现实验。这个文件定义真实函数和数据生成方式。
1: 导入 dataclass3: 导入 torch4: 导入 TensorDataset
7: @dataclass8: class DatasetBundle:
9: train_local
10: val_global
11: test_global
14-15
set_seed 只做一件事:设置 torch.manual_seed(seed)。18: def true_function(x)
19
20-25
这个设计的目的:
28: sample_local_data(...)
29
[low, high] 里均匀采样一维输入。30
31
32
33
36: make_global_grid(...)
37
torch.linspace 均匀覆盖整个定义域。38
39
(x, y)。42: create_datasets(config)
43
45-51
52-56
57-61
63-67
DatasetBundle 返回。这个文件的关键思想是:
故意制造“训练数据只在局部可见,但评估要求看全域”的设定。
这个文件是整个 toy 的知识核心。
1: 导入 asdict 和 dataclass3: 导入 torch5: 从 data.py 导入真实函数 true_function这里直接调用 true_function 的目的,是用真函数来构造“理想化知识 landmarks”。
8: @dataclass(frozen=True)9: class Landmark:
10: name
11-12: x_low, x_high
13-14: y_low, y_high
15: quality
good、shifted。18: _interval_range(x_low, x_high, padding=0.0, shift=0.0)
19
20
21
这里的 padding 和 shift 非常关键:
padding 控制“知识区间多宽”shift 控制“知识是否整体偏移”24: _base_intervals()
[-3,-2], [-2,-1], ..., [2,3]35: _build_landmarks(intervals, paddings, shifts, quality_labels)
36
37-39
40
_interval_range 算出输出区间。41-49
Landmark 对象。51
54-55
58: get_landmarks(set_name)
59
60-66
67-73
74-80
0.3581-87
0.5588
91-92
95: sample_landmark_support(...)
96
seed + 1000 建一个独立生成器,避免和训练数据完全重叠。97-100
102-107
109-114
这个文件的关键思想是:
landmark 不是孤零零的矩形,而是会被采样成一批 support 点,进入训练损失。
这个文件定义回归模型。
1: 导入 torch.nn as nn4: class TinyRegressor(nn.Module):
5
6
7-13
1 -> hiddenTanhhidden -> hiddenTanhhidden -> 115-16
self.net(x)。它的角色非常简单:
作为一个低门槛 baseline/kd 共用模型,不把复杂性放到网络结构里,而是放到损失项里。
这个文件实现训练、评估和可视化。
1-4
copy、json、math、Path6-10
matplotlib、torch、F、画矩形用的 patches、DataLoader12
14
Agg 后端15
pyplotManualAdamW 和前一个项目思路一样,原因也是一样:避免环境中 torch.optim 的问题。
18-32
34-37
39-64
67-68
_make_optimizer 统一返回 ManualAdamW71: interval_violation(pred, y_low, y_high)
72
y_low - predpred - y_high75: knowledge_distance(...)
L2 的具体实现。81
82
83
clamp_min(1e-6)。84
85
这个设计的直觉是:
88-89
sqrt(MSE) 算 RMSE。92: evaluate(model, x, y, config, support)
93
94-99
101-110
predsupport_predviolationknowledge_penalty112
113
115-121
122-125
129: train_baseline(datasets, support, config)
130
131
132
134-136
138-163
139 进入 train140-141 初始化累计量143-153 遍历训练 batch:
155
156-157
159-161
163-164
167: train_knowledge_guided(...)
168-170
172-175
177-179
181-225
188-190
191
193
194
L1 = MSE。196
197-202
L2 = knowledge_distance(...).mean()204
lambda_data * data_loss + (1 - lambda_data) * knowledge_loss205-206
208-212
214
215-218
220-222
224-225
228-230
233-252
255: _draw_landmarks(ax, landmarks)
257-258
259-268
good 或 coarse,就用橙色269
272: plot_prediction_curves(...)
273-274
276-279
281-282
284-292
294-296
这个文件的关键思想可以概括为:
baseline 只有
L1,knowledge-guided 有L1 + L2,而L2来自 landmarks 支持点上的区间惩罚。
这个文件负责组织整个实验流程。
1
2
3-10
13: run_experiment(config, save_artifacts=True, save_plots=True)
14
15
16
17
18-22
24
25
27-28
30-43
45-60
metrics.json62
run.py 打印。这个文件是单次实验入口。
1-2
4
5
6
9: parse_args()
10
ArgumentParser。11
--seed12
--epochs13
--num-train-local14
--lambda-data15
--center-pull-weight16
--landmark-setchoices=available_landmark_sets() 限制输入合法。17
--label-noise-std18
--experiment-name19
--skip-plots20
23: build_config(args)
ExperimentConfig。24-34
33
35
36
39: main()
40
41
42
run_experiment() 真正运行实验。44
45
46
47
48
51-52
data.py 故意制造“训练只在局部窗口、评估看全局”的困难场景。landmarks.py 把知识写成输入区间-输出区间对,而不是点标签。trainer.py 里的 knowledge_distance 就是 toy 版 L2。lambda_data 决定数据项和知识项谁更强。good / coarse_good / mixed / shifted_bad 让你可以系统比较知识质量的影响。如果你下一步要改这套代码,最推荐优先改:
true_functionlocal_region_low/highget_landmarksknowledge_distancelambda_data 和 center_pull_weight