Py学习  »  Python

100行用Python实现自动求导(不import任何包的情况下)

深度学习这件小事 • 4 年前 • 352 次点击  

来源|知乎  作者|而今听雨

链接|https://zhuanlan.zhihu.com/p/438685307
编辑|人工智能前沿讲习
因为看到了这个问题:
如何在python中不使用库和自动求导工具求函数导数?https://www.zhihu.com/question/501734446
觉得好玩就手撸了一个自动求导,一句import都没有,连builtin包都不引用的那种
原理很简单:构建一棵表达式树,树的叶子结点(leaf)是变量和常量,包括根结点(root)在内的其他结点都是运算符结点。
把常用运算符(四则运算、常用函数之类的)计算和求导都写在一个类里,就OK了
过于复杂的式子应该先跑一个DFS,把每个结点的gradient存起来避免重复计算,简单式子就无所谓了,写着玩的~
# Hand-written on purpose so the module needs no imports at all (not even math.prod).
def product(items):
    """Return the product of every element of *items*.

    The running result starts at 1 and is multiplied on the left
    (``running * value``), so an empty iterable yields 1.
    """
    running = 1
    for value in items:
        running = running * value
    return running
# Inheritance hierarchy:
#   Node
#   +-- Constant
#   +-- Variable
#   +-- Operator (Add, Multiply, Divide, Power)
# Every Operator has child nodes (its inputs); Constant and Variable are leaves.
class Node:
    """Base class for every node of the expression tree.

    Nodes are identified by ``name``: two nodes compare equal (and hash
    equally) when their names are equal.

    Args:
        name: identifier of the node (a string for variables/operators,
            the literal number for constants).
        value: current numeric value, 0 by default.
    """

    def __init__(self, name, value=0):
        self.name = name
        self.value = value

    def __eq__(self, other):
        # Only nodes are comparable by name; for anything else let Python
        # fall back (identity) instead of crashing on a missing `.name`.
        if not isinstance(other, Node):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        # Defining __eq__ alone makes Python set __hash__ to None; hash on the
        # same key used for equality so nodes can go into sets / dict keys.
        return hash(self.name)

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return self.__str__()

class Constant(Node):
    """Leaf node holding a fixed number.

    The number is used both as the node's name and as its value, so
    ``Constant(5)`` prints as ``5``.
    """

    def __init__(self, value):
        super().__init__(name=value, value=value)

    def compute_value(self):
        return self.value

    def compute_derivative(self, to_variable):
        # A constant never changes, whatever variable we differentiate by.
        return 0

class Variable(Node):
    """Leaf node for an independent variable.

    Its current value is the ``value`` attribute (0 unless assigned).
    """

    def compute_value(self):
        return self.value

    def compute_derivative(self, to_variable):
        # d(v)/d(w) is 1 when w is this same variable (matched by name), else 0.
        return 1 if to_variable.name == self.name else 0

class Operator(Node):
    """Interior node of the expression tree.

    ``inputs`` is the list of child nodes. ``name`` has the form
    ``"Opt <Kind> of [<children>]"``, and ``__str__`` reads ``<Kind>`` back
    out of it. ``Node.__init__`` is not called, so operator nodes have no
    ``value`` attribute.
    """

    def __init__(self, inputs, name):
        self.inputs = inputs
        self.name = f"Opt {name} of {inputs}"

    def __str__(self):
        # Render as a fully parenthesised infix expression, e.g. "(x+5)".
        symbol = {
            "Add": "+",
            "Power": "^",
            "Multiply": "*",
            "Divide": "/",
        }[self.name.split(" ")[1]]
        rendered_children = [str(child) for child in self.inputs]
        return "(" + symbol.join(rendered_children) + ")"

class Add(Operator):
    """n-ary sum node: its value is the sum of all its inputs."""

    def __init__(self, inputs):
        super().__init__(inputs, name="Add")

    def compute_value(self):
        child_values = (child.compute_value() for child in self.inputs)
        return sum(child_values)

    def compute_derivative(self, to_variable):
        # Linearity: the derivative of a sum is the sum of the derivatives.
        child_slopes = (child.compute_derivative(to_variable) for child in self.inputs)
        return sum(child_slopes)

class Multiply(Operator):
    """n-ary product node: its value is the product of all its inputs."""

    def __init__(self, inputs):
        super().__init__(inputs, name="Multiply")

    def compute_value(self):
        return product(inp.compute_value() for inp in self.inputs)

    def compute_derivative(self, to_variable):
        """Generalised product rule: d(f1*...*fn) = sum_i f_i' * prod_{j != i} f_j.

        Args:
            to_variable: the Variable to differentiate with respect to.

        Returns:
            The partial derivative evaluated at the current variable values.
        """
        # Exclude the i-th factor by *position*, not by equality: nodes compare
        # equal by name, so for x * x (or two identical sub-expressions) an
        # equality filter drops both copies and yields 2 instead of 2x.
        values = [inp.compute_value() for inp in self.inputs]
        return sum(
            inp.compute_derivative(to_variable)
            * product(v for j, v in enumerate(values) if j != i)
            for i, inp in enumerate(self.inputs)
        )

class Divide(Operator):
    """Binary quotient node; ``inputs`` must be exactly [numerator, denominator]."""

    def __init__(self, inputs):
        super().__init__(inputs, name="Divide")

    def compute_value(self):
        num, den = [node.compute_value() for node in self.inputs]
        return num / den

    def compute_derivative(self, to_variable):
        # Quotient rule: (a/b)' = (a' * b - b' * a) / b**2
        num, den = [node.compute_value() for node in self.inputs]
        d_num, d_den = [node.compute_derivative(to_variable) for node in self.inputs]
        numerator_term = d_num * den - d_den * num
        return numerator_term / (den ** 2)

class Power(Operator):
    """Power node ``base ** exponent`` with an exponent that is constant.

    ``inputs`` is ``[base, exponent]``. Differentiation uses the power rule,
    so it is only supported with respect to variables the exponent does not
    depend on.
    """

    def __init__(self, inputs):
        super().__init__(inputs, name="Power")

    def compute_value(self):
        base, exponent = self.inputs
        return base.compute_value() ** exponent.compute_value()

    def compute_derivative(self, to_variable):
        """Return d(base ** n)/d(to_variable) = n * base**(n-1) * base'.

        Raises:
            NotImplementedError: if the exponent depends on ``to_variable``
                (that needs the logarithmic rule, which is not implemented).
        """
        base, exponent = self.inputs
        if exponent.compute_derivative(to_variable) != 0:
            # Returning the power-rule result here would be silently wrong.
            raise NotImplementedError(
                "Power only supports exponents that are constant with respect "
                "to the differentiation variable"
            )
        n = exponent.compute_value()
        d_base = base.compute_derivative(to_variable)
        # Skip zero terms so base ** (n - 1) is never evaluated when it cannot
        # matter; e.g. d/dx (x ** 0) at x = 0 would otherwise raise
        # ZeroDivisionError from 0 ** -1.
        if n == 0 or d_base == 0:
            return 0
        return n * (base.compute_value() ** (n - 1)) * d_base

if __name__ == "__main__":
    # Smoke test: d/dx (x + 5) should print 1.
    # compute_derivative needs the variable to differentiate by.
    x = Variable("x")
    print(Add([x, Constant(5)]).compute_derivative(x))
到这里就已经可以工作了。不过构建每一项和每个运算符都要手动实例化一个类,着实麻烦。可以通过为所有结点重载(overload)运算符方法(__add__、__mul__ 等)来更优雅地构建较长的式子,例如:3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10
def wrapper_opt(opt, self, other, r=False):
    """Build the operator node behind a Python arithmetic dunder.

    Args:
        opt: operator key, one of "add", "mul", "pow", "div".
        self: the Node whose dunder method was invoked.
        other: the other operand; non-Node values are wrapped in a Constant.
        r: True for reflected dunders (e.g. __radd__), where ``other`` is the
            left-hand operand and must come first in the inputs.

    Returns:
        A new Add / Multiply / Power / Divide node.
    """
    node_class = {"add": Add, "mul": Multiply, "pow": Power, "div": Divide}[opt]
    operand = other if isinstance(other, Node) else Constant(other)
    if r:
        ordered = [operand, self]
    else:
        ordered = [self, operand]
    return node_class(inputs=ordered)

# Overload Python's arithmetic operators on every Node so expressions such as
# 3 * (x ** 2) + 5 * x * y can be written directly; each operator builds a new
# tree node through wrapper_opt.
Node.__add__ = lambda self, other: wrapper_opt("add", self, other)
Node.__mul__ = lambda self, other: wrapper_opt("mul", self, other)
Node.__truediv__ = lambda self, other: wrapper_opt("div", self, other)
Node.__pow__ = lambda self, other: wrapper_opt("pow", self, other)
# a - b is encoded as a + (-1 * b), so no dedicated Sub operator is needed.
Node.__sub__ = lambda self, other: wrapper_opt(
    "add", self, wrapper_opt("mul", Constant(-1), other))
Node.__radd__ = lambda self, other: wrapper_opt("add", self, other, r=True)
Node.__rmul__ = lambda self, other: wrapper_opt("mul", self, other, r=True)
Node.__rtruediv__ = lambda self, other: wrapper_opt("div", self, other, r=True)
# Without these, `10 - x` and `-x` raise TypeError. Same encoding as above:
# other - self == other + (-1 * self), and -self == -1 * self.
Node.__rsub__ = lambda self, other: wrapper_opt(
    "add", wrapper_opt("mul", Constant(-1), self), other, r=True)
Node.__neg__ = lambda self: wrapper_opt("mul", Constant(-1), self)

if __name__ == "__main__":
    # Demo: build f(x, y) = 3x^2 + 5xy + 6/x - 8y^2 + 10 with operator syntax,
    # evaluate it and its partial derivative w.r.t. x at (x, y) = (18, 2),
    # then print the expression tree in infix form.
    x = Variable(name="x")
    y = Variable(name="y")
    function = 3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10

    x.value, y.value = 18, 2
    value = function.compute_value()
    print(value)
    slope_x = function.compute_derivative(x)
    print(slope_x)
    print(function)
小tip:把减法a-b定义成a+(-1*b)可以省去一个Sub运算符。其实同理除法a/b也可以定义成a*pow(b,-1)。
输出为:
1130.3333333333333
117.98148148148148
(((((3*(x^2))+((5*x)*y))+(6/x))+(-1*(8*(y^2))))+10)

技术交流群邀请函



△长按添加小助手

扫描二维码添加小助手微信

请备注:姓名-学校/公司-研究方向-城市
(如:小事-浙大-对话系统-北京)
即可申请加入深度学习/机器学习等技术交流群

为您推荐

Github大盘点!2021年最惊艳的38篇AI论文

人工智能领域有哪些曾被拒稿的优秀工作?

思考丨到底什么叫算法工程师的落地能力?

Transformer模型有多少种变体?看看这篇全面综述
各种注意力机制的PyTorch实现

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/127176