社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  Python

100行用Python实现自动求导(不import任何包的情况下)

程序员遇见GitHub • 3 年前 • 403 次点击  

公众号关注 “程序员遇见GitHub

设为“星标”,重磅干货,第一时间送达

来源|知乎  作者|而今听雨

链接|https://zhuanlan.zhihu.com/p/438685307
编辑|人工智能前沿讲习
因为看到了这个问题:
如何在python中不使用库和自动求导工具求函数导数?https://www.zhihu.com/question/501734446
觉得好玩就手撸了一个自动求导,一句import都没有,连builtin包都不引用的那种
原理很简单,构建一个树,树的leave是变量和常量,其他包括root在内的结点都是运算符结点
把常用运算符(四则运算、常用函数之类的)计算和求导都写在一个类里,就OK了
过于复杂的式子应该先跑一个DFS,把每个结点的gradient存起来避免重复计算,简单式子就无所谓了,写着玩的~
# 为了不import,我甚至连product都是自己写的=_=def product(items):    res = 1    for i in items:        res = res * i    return res
# 继承关系:# Node # Operator
# 所有的Operator都有子节点,所有的Constant和Variable都没有子结点class Node: def __init__(self, name, value=0): self.name = name self.value = value
def __eq__(self, other): return self.name == other.name
def __str__(self): return str(self.name)
def __repr__(self): return self.__str__()

class Constant(Node): def __init__(self, value): super().__init__(value, value)
def compute_value(self): return self.value
def compute_derivative(self, to_variable): return 0

class Variable(Node): def compute_value(self): return self.value
def compute_derivative(self, to_variable): if to_variable.name == self.name: return 1 else: return 0

class Operator(Node): def __init__(self, inputs, name): self.inputs = inputs self.name = f"Opt {name} of {inputs}"
def __str__(self): opt2str = {"Add": "+", "Power": "^", "Multiply": "*", "Divide": "/"} return "(" + opt2str[self.name.split(" ")[1]].join(map(str, self.inputs)) + ")"

class Add(Operator): def __init__(self, inputs): super().__init__(inputs, name="Add")
def compute_value(self): return sum(inp.compute_value() for inp in self.inputs)
def compute_derivative(self, to_variable): return sum(inp.compute_derivative(to_variable) for inp in self.inputs)

class Multiply(Operator): def __init__(self, inputs): super().__init__(inputs, name="Multiply")
def compute_value(self): return product(inp.compute_value() for inp in self.inputs)
def compute_derivative(self, to_variable): return sum( inp.compute_derivative(to_variable) * product( other_inp.compute_value() for other_inp in self.inputs if other_inp != inp ) for inp in self.inputs )

class Divide(Operator): def __init__(self, inputs): super().__init__(inputs, name="Divide")
def compute_value(self): a, b = [inp.compute_value() for inp in self.inputs] return a / b
def compute_derivative(self, to_variable): a, b = [inp.compute_value() for inp in self.inputs] da, db = [inp.compute_derivative(to_variable) for inp in self.inputs] return (da * b - db * a) / (b ** 2)

class Power(Operator): # Constant Power def __init__(self, inputs): super().__init__(inputs, name="Power")
def compute_value(self): x, n = self.inputs n = n.value return x.compute_value() ** n
def compute_derivative(self, to_variable): x, n = self.inputs n = n.value return n * (x.compute_value() ** (n - 1)) * x.compute_derivative(to_variable)

if __name__ == "__main__":    print(Add([Varaible("x"),Constant(5)]).compute_derivative())
到这里就可以work了,不过构建每个项和运算符都要实例化一个类,着实是麻烦,可以通过重写所有结点的运算符的方法来更优雅地构建较长的式子,例如像是这种:3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10
def wrapper_opt(opt, self, other, r=False):    opt2class = {"add": Add, "mul": Multiply, "pow": Power, "div": Divide}    if not isinstance(other, Node):        other = Constant(other)    inputs = [other, self] if r else [self, other]    node = opt2class[opt](inputs=inputs)    return node

Node.__add__ = lambda self, other: wrapper_opt("add", self, other)Node.__mul__ = lambda self, other: wrapper_opt("mul", self, other)Node.__truediv__ = lambda self, other: wrapper_opt("div", self, other)Node.__pow__ = lambda self, other: wrapper_opt("pow", self, other)Node.__sub__ = lambda self, other: wrapper_opt( "add", self, wrapper_opt("mul", Constant(-1), other))Node.__radd__ = lambda self, other: wrapper_opt("add", self, other, r=True)Node.__rmul__ = lambda self, other: wrapper_opt("mul", self, other, r=True)Node.__rtruediv__ = lambda self, other: wrapper_opt("div", self, other, r=True)

if __name__ == "__main__": x = Variable(name="x") y = Variable(name="y") function = 3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10
x.value = 18 y.value = 2 print(function.compute_value()) print(function.compute_derivative(x)) print(function)
小tip:把减法a-b定义成a+(-1*b)可以省去一个Sub运算符。其实同理除法a/b也可以定义成a*pow(b,-1)。
输出为:
5.0
1130.3333333333333
(((((3*(x^2))+((5*x)*y))+(6/x))+(-1*(8*(y^2))))+10)

推荐阅读:

我教你如何读博!

牛逼!轻松高效处理文本数据神器

B站强化学习大结局!

如此神器,得之可得顶会!

兄弟们!神经网络画图,有它不愁啊

太赞了!东北大学朱靖波,肖桐团队开源《机器翻译:统计建模与深度学习方法》

当年毕业答辩!遗憾没有它...

已开源!所有李航老师《统计学习方法》代码实现

这个男人,惊为天人!手推PRML!

它来了!《深度学习》(花书) 数学推导、原理剖析与代码实现

你们心心念念的MIT教授Gilbert Strang线性代数彩板笔记!强烈推荐!

GitHub超过9800star!学习Pytorch,有这一份资源就够了!强推!

你真的懂神经网络?强推一个揭秘神经网络的工具,ANN Visualizer

诸位!看我如何白嫖2020 icassp!

这个时代研究情感分析,是最好也是最坏!

BERT雄霸天下!

玩转Pytorch,搞懂这个教程就可以了,从GAN到词嵌入都有实例

是他,是他,就是他!宝藏博主让你秒懂Transformer、BERT、GPT!

fitlog!复旦邱锡鹏老师组内部调参工具!一个可以节省一篇论文的调参利器

Github开源!查阅arXiv论文新神器,一行代码比较版本差别,我爱了!

开源!数据结构与算法必备的 50 个代码实现

他来了!吴恩达带着2018机器学习入门高清视频,还有习题解答和课程拓展来了!

太赞了!复旦邱锡鹏老师NLP实战code解读开源!

这块酷炫的Python神器!我真的爱了,帮助你深刻理解语言本质!实名推荐!

论文神器!易搜搭

不瞒你说!这可能是世界上最好的线性代数教程

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/127179
 
403 次点击