Informed_machine_learning

knowledge_landmarks_toy 代码逐行解析

这份文档对应当前目录下的 knowledge_landmarks_toy 项目,目标是把“局部数据 + 全局 landmarks + 组合损失”这套思路彻底拆开。

说明:


1. 先看整个项目的数据流

这个 toy 的核心执行路径是:

  1. run.py
    • 解析参数
    • 构造 ExperimentConfig
    • 调用 experiment.py
  2. experiment.py
    • 设随机种子
    • 生成局部训练数据和全局验证/测试网格
    • 读取 landmarks
    • 从 landmarks 采 support points
    • 训练 baseline 和 knowledge-guided 两个模型
    • 保存指标和图
  3. trainer.py
    • baseline:只拟合局部有标签数据
    • knowledge-guided:同时拟合局部数据和 landmarks regularizer
  4. landmarks.py
    • 定义输入区间和对应输出区间
    • 构造不同质量的 landmarks 集
  5. data.py
    • 定义真实函数
    • 只在局部窗口采样训练数据
    • 在全域上生成验证/测试曲线

所以这套代码的核心思想是:

baseline 只看局部窗口,knowledge-guided 还会被全局 landmarks 拉住,因此更可能在窗口外也表现稳定。


2. config.py

这个文件定义所有实验超参数。

2.1 导入

2.2 配置类

2.3 数据规模和随机种子

2.4 模型与训练

2.5 组合损失相关

2.6 数据窗口和噪声

2.7 全域范围与 landmarks support

2.8 运行环境和输出目录

2.9 创建输出目录

2.10 转成字典


3. data.py

这个文件定义真实函数和数据生成方式。

3.1 导入与数据结构

3.2 随机种子

3.3 真实函数

这个设计的目的:

3.4 局部训练数据采样

3.5 全域网格生成

3.6 数据打包

这个文件的关键思想是:

故意制造“训练数据只在局部可见,但评估要求看全域”的设定。


4. landmarks.py

这个文件是整个 toy 的知识核心。

4.1 导入

这里直接调用 true_function 的目的,是用真函数来构造“理想化知识 landmarks”。

4.2 Landmark 数据结构

4.3 从真实函数计算输出区间

这里的 paddingshift 非常关键:

4.4 输入区间划分

4.5 根据配置批量生成 landmarks

4.6 可用 landmark set 名称

4.7 具体 landmark set 定义

good

coarse_good

mixed

shifted_bad

4.8 序列化

4.9 从 landmark 采 support 点

这个文件的关键思想是:

landmark 不是孤零零的矩形,而是会被采样成一批 support 点,进入训练损失。


5. model.py

这个文件定义回归模型。

5.1 导入

5.2 模型类

它的角色非常简单:

作为一个低门槛 baseline/kd 共用模型,不把复杂性放到网络结构里,而是放到损失项里。


6. trainer.py

这个文件实现训练、评估和可视化。

6.1 导入

6.2 手写优化器

ManualAdamW 和前一个项目思路一样,原因也是一样:避免环境中 torch.optim 的问题。

初始化

清梯度

更新一步

6.3 优化器工厂

6.4 区间出界惩罚

6.5 知识距离

这个设计的直觉是:

6.6 RMSE 计算

6.7 评估函数

6.8 baseline 训练

6.9 knowledge-guided 训练

每个 batch

每轮末尾

6.10 保存指标

6.11 训练曲线图

6.12 画 landmarks

6.13 预测曲线图

这个文件的关键思想可以概括为:

baseline 只有 L1,knowledge-guided 有 L1 + L2,而 L2 来自 landmarks 支持点上的区间惩罚。


7. experiment.py

这个文件负责组织整个实验流程。

7.1 导入

7.2 总实验入口


8. run.py

这个文件是单次实验入口。

8.1 导入

8.2 参数解析

8.3 参数构造成配置

8.4 主函数


9. 这套代码最值得你记住的 5 个点

  1. data.py 故意制造“训练只在局部窗口、评估看全局”的困难场景。
  2. landmarks.py 把知识写成输入区间-输出区间对,而不是点标签。
  3. trainer.py 里的 knowledge_distance 就是 toy 版 L2
  4. lambda_data 决定数据项和知识项谁更强。
  5. good / coarse_good / mixed / shifted_bad 让你可以系统比较知识质量的影响。

如果你下一步要改这套代码,最推荐优先改: