公众号关注 “程序员遇见GitHub”
设为“星标”,重磅干货,第一时间送达

来源|知乎 作者|而今听雨
链接|https://zhuanlan.zhihu.com/p/438685307如何在python中不使用库和自动求导工具求函数导数?https://www.zhihu.com/question/501734446觉得好玩就手撸了一个自动求导,一句import都没有,连builtin包都不引用的那种原理很简单,构建一个树,树的leave是变量和常量,其他包括root在内的结点都是运算符结点把常用运算符(四则运算、常用函数之类的)计算和求导都写在一个类里,就OK了过于复杂的式子应该先跑一个DFS,把每个结点的gradient存起来避免重复计算,简单式子就无所谓了,写着玩的~# 为了不import,我甚至连product都是自己写的=_=
def product(items):
res = 1
for i in items:
res = res * i
return res
# 继承关系:
# Node
# Operator
# 所有的Operator都有子节点,所有的Constant和Variable都没有子结点
class Node:
def __init__(self, name, value=0):
self.name = name
self.value = value
def __eq__(self, other):
return self.name == other.name
def __str__(self):
return str(self.name)
def __repr__(self):
return self.__str__()
class Constant(Node):
def __init__(self, value):
super().__init__(value, value)
def compute_value(self):
return self.value
def compute_derivative(self, to_variable):
return 0
class Variable(Node):
def compute_value(self):
return self.value
def compute_derivative(self, to_variable):
if to_variable.name == self.name:
return 1
else:
return 0
class Operator(Node):
def __init__(self, inputs, name):
self.inputs = inputs
self.name = f"Opt {name} of {inputs}"
def __str__(self):
opt2str = {"Add": "+", "Power": "^", "Multiply": "*", "Divide": "/"}
return "(" + opt2str[self.name.split(" ")[1]].join(map(str, self.inputs)) + ")"
class Add(Operator):
def __init__(self, inputs):
super().__init__(inputs, name="Add")
def compute_value(self):
return sum(inp.compute_value() for inp in self.inputs)
def compute_derivative(self, to_variable):
return sum(inp.compute_derivative(to_variable) for inp in self.inputs)
class Multiply(Operator):
def __init__(self, inputs):
super().__init__(inputs, name="Multiply")
def compute_value(self):
return product(inp.compute_value() for inp in self.inputs)
def compute_derivative(self, to_variable):
return sum(
inp.compute_derivative(to_variable)
* product(
other_inp.compute_value()
for other_inp in self.inputs
if other_inp != inp
)
for inp in self.inputs
)
class Divide(Operator):
def __init__(self, inputs):
super().__init__(inputs, name="Divide")
def compute_value(self):
a, b = [inp.compute_value() for inp in self.inputs]
return a / b
def compute_derivative(self, to_variable):
a, b = [inp.compute_value() for inp in self.inputs]
da, db = [inp.compute_derivative(to_variable) for inp in self.inputs]
return (da * b - db * a) / (b ** 2)
class Power(Operator):
# Constant Power
def __init__(self, inputs):
super().__init__(inputs, name="Power")
def compute_value(self):
x, n = self.inputs
n = n.value
return x.compute_value() ** n
def compute_derivative(self, to_variable):
x, n = self.inputs
n = n.value
return n * (x.compute_value() ** (n - 1)) * x.compute_derivative(to_variable)
if __name__ == "__main__":
print(Add([Varaible("x"),Constant(5)]).compute_derivative())
到这里就可以work了,不过构建每个项和运算符都要实例化一个类,着实是麻烦,可以通过重写所有结点的运算符的方法来更优雅地构建较长的式子,例如像是这种:3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10def wrapper_opt(opt, self, other, r=False):
opt2class = {"add": Add, "mul": Multiply, "pow": Power, "div": Divide}
if not isinstance(other, Node):
other = Constant(other)
inputs = [other, self] if r else [self, other]
node = opt2class[opt](inputs=inputs)
return node
Node.__add__ = lambda self, other: wrapper_opt("add", self, other)
Node.__mul__ = lambda self, other: wrapper_opt("mul", self, other)
Node.__truediv__ = lambda self, other: wrapper_opt("div", self, other)
Node.__pow__ = lambda self, other: wrapper_opt("pow", self, other)
Node.__sub__ = lambda self, other: wrapper_opt(
"add", self, wrapper_opt("mul", Constant(-1), other)
)
Node.__radd__ = lambda self, other: wrapper_opt("add", self, other, r=True)
Node.__rmul__ = lambda self, other: wrapper_opt("mul", self, other, r=True)
Node.__rtruediv__ = lambda self, other: wrapper_opt("div", self, other, r=True)
if __name__ == "__main__":
x = Variable(name="x")
y = Variable(name="y")
function = 3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10
x.value = 18
y.value = 2
print(function.compute_value())
print(function.compute_derivative(x))
print(function)
小tip:把减法a-b定义成a+(-1*b)可以省去一个Sub运算符。其实同理除法a/b也可以定义成a*pow(b,-1)。5.0
1130.3333333333333
(((((3*(x^2))+((5*x)*y))+(6/x))+(-1*(8*(y^2))))+10)推荐阅读:
我教你如何读博!
牛逼!轻松高效处理文本数据神器
B站强化学习大结局!
如此神器,得之可得顶会!
兄弟们!神经网络画图,有它不愁啊
太赞了!东北大学朱靖波,肖桐团队开源《机器翻译:统计建模与深度学习方法》
当年毕业答辩!遗憾没有它...
已开源!所有李航老师《统计学习方法》代码实现
这个男人,惊为天人!手推PRML!
它来了!《深度学习》(花书) 数学推导、原理剖析与代码实现
你们心心念念的MIT教授Gilbert Strang线性代数彩板笔记!强烈推荐!
GitHub超过9800star!学习Pytorch,有这一份资源就够了!强推!
你真的懂神经网络?强推一个揭秘神经网络的工具,ANN Visualizer
诸位!看我如何白嫖2020 icassp!
这个时代研究情感分析,是最好也是最坏!
BERT雄霸天下!
玩转Pytorch,搞懂这个教程就可以了,从GAN到词嵌入都有实例
是他,是他,就是他!宝藏博主让你秒懂Transformer、BERT、GPT!
fitlog!复旦邱锡鹏老师组内部调参工具!一个可以节省一篇论文的调参利器
Github开源!查阅arXiv论文新神器,一行代码比较版本差别,我爱了!
开源!数据结构与算法必备的 50 个代码实现
他来了!吴恩达带着2018机器学习入门高清视频,还有习题解答和课程拓展来了!
太赞了!复旦邱锡鹏老师NLP实战code解读开源!
这块酷炫的Python神器!我真的爱了,帮助你深刻理解语言本质!实名推荐!
论文神器!易搜搭
不瞒你说!这可能是世界上最好的线性代数教程