Pytorch 学习笔记(8): PyTorch FX

张开发
2026/6/12 3:33:05 15 分钟阅读
Pytorch 学习笔记(8): PyTorch FX
一、FX 是什么FX是 PyTorch 提供的模型转换工具包核心功能是将nn.Module转换为可分析、可修改的中间表示IR再生成新的 Python 代码。FX 三大核心组件组件作用符号追踪器 (Symbolic Tracer)通过符号执行捕获模型语义中间表示 (Graph/IR)用 Graph 结构表示计算流程Python 代码生成将修改后的 Graph 转回可执行代码快速示例importtorchfromtorch.fximportsymbolic_traceclassMyModule(torch.nn.Module):def__init__(self):super().__init__()self.paramtorch.nn.Parameter(torch.rand(3,4))self.lineartorch.nn.Linear(4,5)defforward(self,x):returnself.linear(xself.param).clamp(min0.0,max1.0)# 1. 符号追踪moduleMyModule()symbolic_tracedsymbolic_trace(module)# 2. 查看中间表示Graphprint(symbolic_traced.graph) graph(): %x : [num_users1] placeholder[targetx] %param : [num_users1] get_attr[targetparam] %add : [num_users1] call_function[targetoperator.add](...) %linear : [num_users1] call_module[targetlinear](...) %clamp : [num_users1] call_method[targetclamp](...) return clamp # 3. 查看生成的代码print(symbolic_traced.code) def forward(self, x): param self.param add x param; x param None linear self.linear(add); add None clamp linear.clamp(min0.0, max1.0); linear None return clamp 二、Graph 结构详解FX 的 Graph 由Node组成每个 Node 代表一个操作。Node 的六种操作类型opcodeopcode含义示例placeholder函数输入参数%x placeholder[targetx]get_attr获取模块属性/参数%param get_attr[targetparam]call_function调用函数如torch.addcall_function[targetoperator.add]call_module调用子模块的 forwardcall_module[targetlinear]call_method调用 Tensor 方法call_method[targetclamp]output返回值return clamp打印 Graph 表格gm.graph.print_tabular()输出opcode name target args kwargs ------------- ------ ----------------------- ---------- -------- placeholder x x () {} get_attr linear_weight linear.weight () {} call_function add_1 built-in function add (x, ...) {} call_module linear_1 linear (add_1,) {} ...三、编写 FX 转换Transformation标准转换模板importtorchimporttorch.fxasfxdeftransform(m:torch.nn.Module,tracer_class:typefx.Tracer)-torch.nn.Module:# Step 1: 获取 Graphgraphtracer_class().trace(m)# Step 2: 修改 Graph# ... 转换逻辑 ...# Step 3: 返回新的 GraphModulereturnfx.GraphModule(m,graph)直接修改 Graph 示例替换算子deftransform(m:torch.nn.Module)-torch.nn.Module:graphfx.Tracer().trace(m)fornodeingraph.nodes:# 找到 torch.add 调用替换为 torch.mulifnode.opcall_functionandnode.targettorch.add:node.targettorch.mul# 直接修改目标函数graph.lint()# 检查 Graph 合法性returnfx.GraphModule(m,graph)插入新节点示例# 在指定节点后插入 ReLUwithtraced.graph.inserting_after(node):new_nodetraced.graph.call_function(torch.relu,args(node,))# 将所有使用原节点的地方替换为新节点node.replace_all_uses_with(new_node)四、高级转换技巧1. 子图重写replace_patternFX 提供查找-替换功能自动匹配并替换子图fromtorch.fximportreplace_pattern# 定义模式要查找的子图defpattern(w1,w2):returntorch.cat([w1,w2])# 定义替换新的子图defreplacement(w1,w2):returntorch.stack([w1,w2])# 执行替换matchesreplace_pattern(gm,pattern,replacement)2. Proxy 重追踪用 Proxy 机制自动记录操作避免手动 Graph 操作defrelu_decomposition(x):将 ReLU 分解为 (x 0) * xreturn(x0)*x decomposition_rules{F.relu:relu_decomposition}defdecompose(model):graphfx.Tracer().trace(model)new_graphfx.Graph()env{}tracertorch.fx.proxy.GraphAppendingTracer(new_graph)fornodeingraph.nodes:ifnode.opcall_functionandnode.targetindecomposition_rules:# 用 Proxy 包装参数自动记录操作proxy_args[fx.Proxy(env[x.name],tracer)ifisinstance(x,fx.Node)elsexforxinnode.args]output_proxydecomposition_rules[node.target](*proxy_args)env[node.name]output_proxy.nodeelse:new_nodenew_graph.node_copy(node,lambdax:env[x.name])env[node.name]new_nodereturnfx.GraphModule(model,new_graph)3. Interpreter 模式逐节点执行 Graph适合分析和转换classShapeProp:形状传播记录每个节点的 shape 和 dtypedef__init__(self,mod):self.modmod self.graphmod.graph self.modulesdict(self.mod.named_modules())defpropagate(self,*args):args_iteriter(args)env{}defload_arg(a):returntorch.fx.graph.map_arg(a,lambdan:env[n.name])fornodeinself.graph.nodes:ifnode.opplaceholder:resultnext(args_iter)elifnode.opget_attr:resultfetch_attr(node.target)elifnode.opcall_function:resultnode.target(*load_arg(node.args),**load_arg(node.kwargs))# ... 其他操作类型 ...# 记录 shape 和 dtypeifisinstance(result,torch.Tensor):node.shaperesult.shape node.dtyperesult.dtype env[node.name]resultreturnload_arg(self.graph.result)五、Transformer 类Transformer是Interpreter的子类用于生成新 GraphclassNegSigmSwapXformer(fx.Transformer):defcall_function(self,target,args,kwargs):iftargetistorch.sigmoid:returntorch.neg(*args,**kwargs)returnsuper().call_function(target,args,kwargs)defcall_method(self,target,args,kwargs):iftargetneg:call_self,*args_tailargsreturncall_self.sigmoid(*args_tail,**kwargs)returnsuper().call_method(target,args,kwargs)# 使用gmfx.symbolic_trace(fn)transformedNegSigmSwapXformer(gm).transform()六、调试技巧1. 检查转换正确性不要用比较 Tensor用torch.allclose()# ❌ 错误assertoriginal(input)transformed(input)# ✅ 正确asserttorch.allclose(original(input),transformed(input))2. 调试生成的代码# 方法1打印代码print(traced.code)# 方法2导出到文件夹traced.to_folder(output_folder,ModuleName)fromoutput_folderimportModuleName# 方法3使用 pdbimportpdb;pdb.set_trace()traced(input)# 单步调试3. 可视化 Graph# 打印表格形式traced.graph.print_tabular()# 打印 Graph 结构print(traced.graph)七、符号追踪的限制❌ 动态控制流不支持deffunc_to_trace(x):ifx.sum()0:# ❌ 错误条件依赖输入值returntorch.relu(x)else:returntorch.neg(x)# 报错TraceError: symbolically traced variables cannot be used as inputs to control flow✅ 静态控制流支持classMyModule(torch.nn.Module):def__init__(self,do_activationFalse):super().__init__()self.do_activationdo_activation# 超参数非输入defforward(self,x):xself.linear(x)ifself.do_activation:# ✅ 正确条件不依赖输入xtorch.relu(x)returnx解决方案concrete_args# 用 concrete_args 绑定具体值fx.symbolic_trace(f,concrete_args{flag:True})八、常用 API 速查API功能symbolic_trace(root, concrete_argsNone)符号追踪wrap(fn_or_name)注册叶子函数GraphModule(root, graph)从 Graph 创建模块graph.call_function(fn, args, kwargs)插入函数调用节点graph.call_module(module_name, args, kwargs)插入模块调用节点node.replace_all_uses_with(new_node)替换所有使用graph.lint()检查 Graph 合法性gm.recompile()重新编译 forwardreplace_pattern(gm, pattern, replacement)子图替换Interpreter.run(*args)解释执行 GraphTransformer.transform()转换并返回新模块九、最佳实践转换后调用graph.lint()- 确保 Graph 结构合法修改后调用gm.recompile()- 同步生成 forward 代码用torch.allclose()验证正确性- 浮点数比较避免 set 迭代- 用 dict 保持确定性顺序标记叶子模块- 对训练标志敏感的模块用is_leaf_module结语FX 是 PyTorch 模型优化的基础设施广泛应用于量化Quantization算子融合Operator Fusion剪枝Pruning分布式训练优化掌握 FX 可以让你深入理解 PyTorch 模型的内部结构实现自定义的编译优化流程。参考资源PyTorch FX 官方文档FX 示例代码库

更多文章