来源|知乎 作者|而今听雨
链接|https://zhuanlan.zhihu.com/p/438685307如何在python中不使用库和自动求导工具求函数导数?https://www.zhihu.com/question/501734446觉得好玩就手撸了一个自动求导,一句import都没有,连builtin包都不引用的那种原理很简单,构建一个树,树的leave是变量和常量,其他包括root在内的结点都是运算符结点把常用运算符(四则运算、常用函数之类的)计算和求导都写在一个类里,就OK了过于复杂的式子应该先跑一个DFS,把每个结点的gradient存起来避免重复计算,简单式子就无所谓了,写着玩的~# 为了不import,我甚至连product都是自己写的=_=
def product(items):
res = 1
for i in items:
res = res * i
return res
# 继承关系:
# Node
# Operator
# 所有的Operator都有子节点,所有的Constant和Variable都没有子结点
class Node:
def __init__(self, name, value=0):
self.name = name
self.value = value
def __eq__(self, other):
return self.name == other.name
def __str__(self):
return str(self.name)
def __repr__(self):
return self.__str__()
class Constant(Node):
def __init__(self, value):
super().__init__(value, value)
def compute_value(self):
return self.value
def compute_derivative(self, to_variable):
return 0
class Variable(Node):
def compute_value(self):
return self.value
def compute_derivative(self, to_variable):
if to_variable.name == self.name:
return 1
else:
return 0
class Operator(Node):
def __init__(self, inputs, name):
self.inputs = inputs
self.name = f"Opt {name} of {inputs}"
def __str__(self):
opt2str = {"Add": "+", "Power": "^", "Multiply": "*", "Divide": "/"}
return "(" + opt2str[self.name.split(" ")[1]].join(map(str, self.inputs)) + ")"
class Add(Operator):
def __init__(self, inputs):
super().__init__(inputs, name="Add")
def compute_value(self):
return sum(inp.compute_value() for inp in self.inputs)
def compute_derivative(self, to_variable):
return sum(inp.compute_derivative(to_variable) for inp in self.inputs)
class Multiply(Operator):
def __init__(self, inputs):
super().__init__(inputs, name="Multiply")
def compute_value(self):
return product(inp.compute_value() for inp in self.inputs)
def compute_derivative(self, to_variable):
return sum(
inp.compute_derivative(to_variable)
* product(
other_inp.compute_value()
for other_inp in self.inputs
if other_inp != inp
)
for inp in self.inputs
)
class Divide(Operator):
def __init__(self, inputs):
super().__init__(inputs, name="Divide")
def compute_value(self):
a, b = [inp.compute_value() for inp in self.inputs]
return a / b
def compute_derivative(self, to_variable):
a, b = [inp.compute_value() for inp in self.inputs]
da, db = [inp.compute_derivative(to_variable) for inp in self.inputs]
return (da * b - db * a) / (b ** 2)
class Power(Operator):
# Constant Power
def __init__(self, inputs):
super().__init__(inputs, name="Power")
def compute_value(self):
x, n = self.inputs
n = n.value
return x.compute_value() ** n
def compute_derivative(self, to_variable):
x, n = self.inputs
n = n.value
return n * (x.compute_value() ** (n - 1)) * x.compute_derivative(to_variable)
if __name__ == "__main__":
print(Add([Varaible("x"),Constant(5)]).compute_derivative())
到这里就可以work了,不过构建每个项和运算符都要实例化一个类,着实是麻烦,可以通过重写所有结点的运算符的方法来更优雅地构建较长的式子,例如像是这种:3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10def wrapper_opt(opt, self, other, r=False):
opt2class = {"add": Add, "mul": Multiply, "pow": Power, "div": Divide}
if not isinstance(other, Node):
other = Constant(other)
inputs = [other, self] if r else [self, other]
node = opt2class[opt](inputs=inputs)
return node
Node.__add__ = lambda self, other: wrapper_opt("add", self, other)
Node.__mul__ = lambda self, other: wrapper_opt("mul", self, other)
Node.__truediv__ = lambda self, other: wrapper_opt("div", self, other)
Node.__pow__ = lambda self, other: wrapper_opt("pow", self, other)
Node.__sub__ = lambda self, other: wrapper_opt(
"add", self, wrapper_opt("mul", Constant(-1), other)
)
Node.__radd__ = lambda self, other: wrapper_opt("add", self, other, r=True)
Node.__rmul__ = lambda self, other: wrapper_opt("mul", self, other, r=True)
Node.__rtruediv__ = lambda self, other: wrapper_opt("div", self, other, r=True)
if __name__ == "__main__":
x = Variable(name="x")
y = Variable(name="y")
function = 3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10
x.value = 18
y.value = 2
print(function.compute_value())
print(function.compute_derivative(x))
print(function)
小tip:把减法a-b定义成a+(-1*b)可以省去一个Sub运算符。其实同理除法a/b也可以定义成a*pow(b,-1)。5.0
1130.3333333333333
(((((3*(x^2))+((5*x)*y))+(6/x))+(-1*(8*(y^2))))+10)
想要了解更多资讯,请扫描下方二维码,关注机器学习研究会

转自: 人工智能前沿讲习