
来源|知乎 作者|而今听雨
链接|https://zhuanlan.zhihu.com/p/438685307如何在python中不使用库和自动求导工具求函数导数?https://www.zhihu.com/question/501734446觉得好玩就手撸了一个自动求导,一句import都没有,连builtin包都不引用的那种原理很简单,构建一个树,树的leave是变量和常量,其他包括root在内的结点都是运算符结点把常用运算符(四则运算、常用函数之类的)计算和求导都写在一个类里,就OK了过于复杂的式子应该先跑一个DFS,把每个结点的gradient存起来避免重复计算,简单式子就无所谓了,写着玩的~# 为了不import,我甚至连product都是自己写的=_=def product(items): res = 1 for i in items: res = res * i return res
# 继承关系:# Node # Operator
# 所有的Operator都有子节点,所有的Constant和Variable都没有子结点class Node: def __init__(self, name, value=0): self.name = name self.value = value
def __eq__(self, other): return self.name == other.name
def __str__(self): return str(self.name)
def __repr__(self): return self.__str__()
class Constant(Node): def __init__(self, value): super().__init__(value, value)
def compute_value(self): return self.value
def compute_derivative(self, to_variable): return 0
class Variable(Node): def compute_value(self): return self.value
def compute_derivative(self, to_variable): if to_variable.name == self.name: return 1 else: return 0
class Operator(Node): def __init__(self, inputs, name): self.inputs = inputs self.name = f"Opt {name} of {inputs}"
def __str__(self): opt2str = {"Add": "+", "Power": "^", "Multiply": "*", "Divide": "/"} return "(" + opt2str[self.name.split(" ")[1]].join(map(str, self.inputs)) + ")"
class Add(Operator): def __init__(self, inputs): super().__init__(inputs, name="Add")
def compute_value(self): return sum(inp.compute_value() for inp in self.inputs)
def compute_derivative(self, to_variable): return sum(inp.compute_derivative(to_variable) for inp in self.inputs)
class Multiply(Operator): def __init__(self, inputs): super().__init__(inputs, name="Multiply")
def compute_value(self): return product(inp.compute_value() for inp in self.inputs)
def compute_derivative(self, to_variable): return sum( inp.compute_derivative(to_variable) * product( other_inp.compute_value() for other_inp in self.inputs if other_inp != inp ) for inp in self.inputs )
class Divide(Operator):
def __init__(self, inputs): super().__init__(inputs, name="Divide")
def compute_value(self): a, b = [inp.compute_value() for inp in self.inputs] return a / b
def compute_derivative(self, to_variable): a, b = [inp.compute_value() for inp in self.inputs] da, db = [inp.compute_derivative(to_variable) for inp in self.inputs] return (da * b - db * a) / (b ** 2)
class Power(Operator): # Constant Power def __init__(self, inputs): super().__init__(inputs, name="Power")
def compute_value(self): x, n = self.inputs n = n.value return x.compute_value() ** n
def compute_derivative(self, to_variable): x, n = self.inputs n = n.value return n * (x.compute_value() ** (n - 1)) * x.compute_derivative(to_variable)
if __name__ == "__main__": print(Add([Varaible("x"),Constant(5)]).compute_derivative())
到这里就可以work了,不过构建每个项和运算符都要实例化一个类,着实是麻烦,可以通过重写所有结点的运算符的方法来更优雅地构建较长的式子,例如像是这种:3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10def wrapper_opt(opt, self, other, r=False): opt2class = {"add": Add, "mul": Multiply, "pow": Power, "div": Divide} if not isinstance(other, Node): other = Constant(other) inputs = [other, self] if r else [self, other] node = opt2class[opt](inputs=inputs) return node
Node.__add__ = lambda self, other: wrapper_opt("add", self, other)Node.__mul__ = lambda self, other: wrapper_opt("mul", self, other)Node.__truediv__ = lambda self, other: wrapper_opt("div", self, other)Node.__pow__ = lambda self, other: wrapper_opt("pow", self, other)Node.__sub__ = lambda self, other: wrapper_opt( "add", self, wrapper_opt("mul", Constant(-1), other))Node.__radd__ = lambda self, other: wrapper_opt("add", self, other, r=True)Node.__rmul__ = lambda self, other: wrapper_opt("mul", self, other, r=True)Node.__rtruediv__ = lambda self, other: wrapper_opt("div", self, other, r=True)
if __name__ == "__main__": x = Variable(name="x") y = Variable(name="y") function = 3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10
x.value = 18 y.value = 2 print(function.compute_value()) print(function.compute_derivative(x)) print(function)
小tip:把减法a-b定义成a+(-1*b)可以省去一个Sub运算符。其实同理除法a/b也可以定义成a*pow(b,-1)。5.0
1130.3333333333333
(((((3*(x^2))+((5*x)*y))+(6/x))+(-1*(8*(y^2))))+10)

扫描二维码添加小助手微信
