
Automatic differentiation in 100 lines of Python (without importing any packages)

Posted by 机器学习研究组订阅 • 3 years ago • 485 views

Source | Zhihu  Author | 而今听雨

Link | https://zhuanlan.zhihu.com/p/438685307

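I came across this question:
How do you compute the derivative of a function in Python without using libraries or automatic differentiation tools? https://www.zhihu.com/question/501734446
It looked like fun, so I hand-wrote an automatic differentiator with not a single import statement; it does not even pull in a standard-library module.
The principle is simple: build a tree whose leaves are variables and constants, while every other node, including the root, is an operator node.
For each common operator (the four arithmetic operations, common functions, and so on), put both its evaluation and its derivative rule in one class, and that is all it takes.
For very complex expressions you should first run a DFS and store each node's gradient to avoid repeated computation. For simple expressions it does not matter; this was written just for fun.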
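The caching idea mentioned above can be sketched as follows. This is a minimal, hypothetical illustration that is independent of the article's classes: nodes are plain tuples, the `evaluate` function and the tuple layout are assumptions made for this example, and value and derivative are computed together in one depth-first pass, with results cached per node object.

```python
# Hypothetical mini expression tree (not the article's classes).
# A node is a tuple: ("const", value), ("var", name),
# ("add", left, right) or ("mul", left, right).

def evaluate(node, env, wrt, cache, visits):
    """Depth-first pass returning (value, derivative) for node.

    Results are cached by object identity, so a sub-expression that is
    shared by several parents is only processed once.
    """
    key = id(node)
    if key in cache:
        return cache[key]
    visits.append(node[0])  # record every real (uncached) visit
    kind = node[0]
    if kind == "const":
        result = (node[1], 0)
    elif kind == "var":
        result = (env[node[1]], 1 if node[1] == wrt else 0)
    else:
        a, da = evaluate(node[1], env, wrt, cache, visits)
        b, db = evaluate(node[2], env, wrt, cache, visits)
        if kind == "add":
            result = (a + b, da + db)
        elif kind == "mul":
            result = (a * b, da * b + a * db)  # product rule
        else:
            raise ValueError("unknown node kind: " + kind)
    cache[key] = result
    return result


x = ("var", "x")
shared = ("mul", x, x)                           # x * x, reused three times
expr = ("add", ("mul", shared, shared), shared)  # x^4 + x^2

cache, visits = {}, []
value, grad = evaluate(expr, {"x": 3}, "x", cache, visits)
print(value, grad)   # value and derivative of x^4 + x^2 at x = 3
print(len(visits))   # number of nodes actually processed
```

With the cache, only four nodes are processed here even though `shared` appears three times in the tree. Keying the cache on `id(node)` is a simplification that is only safe while all nodes stay alive, as they do in this sketch.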
# To avoid importing anything, even product is hand-written =_=
def product(items):
    res = 1
    for i in items:
        res = res * i
    return res
# Inheritance: Node is the base class of Constant, Variable and Operator.
# Every Operator has child nodes; Constant and Variable have none.
class Node:
    def __init__(self, name, value=0):
        self.name = name
        self.value = value

    def __eq__(self, other):
        return self.name == other.name

    def __str__(self):
        return str(self.name)

    def __repr__(self):
        return self.__str__()

class Constant(Node):
    def __init__(self, value):
        super().__init__(value, value)

    def compute_value(self):
        return self.value

    def compute_derivative(self, to_variable):
        return 0

class Variable(Node):
    def compute_value(self):
        return self.value

    def compute_derivative(self, to_variable):
        if to_variable.name == self.name:
            return 1
        else:
            return 0

class Operator(Node):
    def __init__(self, inputs, name):
        self.inputs = inputs
        self.name = f"Opt {name} of {inputs}"

    def __str__(self):
        opt2str = {"Add": "+", "Power": "^", "Multiply": "*", "Divide": "/"}
        return "(" + opt2str[self.name.split(" ")[1]].join(map(str, self.inputs)) + ")"

class Add(Operator):
    def __init__(self, inputs):
        super().__init__(inputs, name="Add")

    def compute_value(self):
        return sum(inp.compute_value() for inp in self.inputs)

    def compute_derivative(self, to_variable):
        return sum(inp.compute_derivative(to_variable) for inp in self.inputs)

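class Multiply(Operator):
    def __init__(self, inputs):
        super().__init__(inputs, name="Multiply")

    def compute_value(self):
        return product(inp.compute_value() for inp in self.inputs)

    def compute_derivative(self, to_variable):
        # Product rule: differentiate one factor at a time and multiply by the
        # values of all other factors. Factors are told apart by position, not
        # by !=, because Node.__eq__ compares names and would treat the two
        # factors of x * x as the same node.
        return sum(
            inp.compute_derivative(to_variable)
            * product(
                other.compute_value()
                for j, other in enumerate(self.inputs)
                if j != i
            )
            for i, inp in enumerate(self.inputs)
        )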

class Divide(Operator):
    def __init__(self, inputs):
        super().__init__(inputs, name="Divide")

    def compute_value(self):
        a, b = [inp.compute_value() for inp in self.inputs]
        return a / b

    def compute_derivative(self, to_variable):
        a, b = [inp.compute_value() for inp in self.inputs]
        da, db = [inp.compute_derivative(to_variable) for inp in self.inputs]
        return (da * b - db * a) / (b ** 2)

class Power(Operator):  # Constant Power
    def __init__(self, inputs):
        super().__init__(inputs, name="Power")

    def compute_value(self):
        x, n = self.inputs
        n = n.value
        return x.compute_value() ** n

    def compute_derivative(self, to_variable):
        x, n = self.inputs
        n = n.value
        return n * (x.compute_value() ** (n - 1)) * x.compute_derivative(to_variable)

if __name__ == "__main__":
    x = Variable("x")
    print(Add([x, Constant(5)]).compute_derivative(x))  # d(x + 5)/dx = 1
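At this point it already works, but having to instantiate a class for every term and every operator is tedious. Overloading the operators on all nodes makes it possible to build longer expressions more elegantly, for example: 3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10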
def wrapper_opt(opt, self, other, r=False):
    opt2class = {"add": Add, "mul": Multiply, "pow": Power, "div": Divide}
    if not isinstance(other, Node):
        other = Constant(other)
    inputs = [other, self] if r else [self, other]
    node = opt2class[opt](inputs=inputs)
    return node

Node.__add__ = lambda self, other: wrapper_opt("add", self, other)
Node.__mul__ = lambda self, other: wrapper_opt("mul", self, other)
Node.__truediv__ = lambda self, other: wrapper_opt("div", self, other)
Node.__pow__ = lambda self, other: wrapper_opt("pow", self, other)
Node.__sub__ = lambda self, other: wrapper_opt(
    "add", self, wrapper_opt("mul", Constant(-1), other)
)
Node.__radd__ = lambda self, other: wrapper_opt("add", self, other, r=True)
Node.__rmul__ = lambda self, other: wrapper_opt("mul", self, other, r=True)
Node.__rtruediv__ = lambda self, other: wrapper_opt("div", self, other, r=True)

if __name__ == "__main__":
    x = Variable(name="x")
    y = Variable(name="y")
    function = 3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10

    x.value = 18
    y.value = 2
    print(function.compute_value())
    print(function.compute_derivative(x))
    print(function)
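Terms such as 3 * (x ** 2) and 6 / x, where a plain number is the left operand, rely on Python's reflected operator methods (`__rmul__`, `__rtruediv__`). The sketch below shows that dispatch in isolation. The `Expr` class and the `as_text` helper are hypothetical stand-ins written for this illustration, not part of the article's code.

```python
class Expr:
    """Hypothetical stand-in for a tree node; it only records how it was built."""

    def __init__(self, text):
        self.text = text

    def __mul__(self, other):       # used for: Expr * anything
        return Expr("(" + self.text + "*" + as_text(other) + ")")

    def __rmul__(self, other):      # used for: number * Expr
        return Expr("(" + as_text(other) + "*" + self.text + ")")

    def __rtruediv__(self, other):  # used for: number / Expr
        return Expr("(" + as_text(other) + "/" + self.text + ")")


def as_text(obj):
    """Render an operand, whether it is an Expr or a plain number."""
    return obj.text if isinstance(obj, Expr) else str(obj)


x = Expr("x")
print((x * 3).text)      # Expr.__mul__ handles this directly
print((3 * x).text)      # int cannot multiply by Expr, so Expr.__rmul__ runs
print((6 / x).text)      # likewise handled by Expr.__rtruediv__
print((3).__mul__(x))    # NotImplemented: the signal that triggers the fallback
```

By the same mechanism, an expression with a number on the left of a minus sign, such as 10 - x, would need a `__rsub__` method, which the definitions above do not provide.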
Tip: defining subtraction a-b as a+(-1*b) avoids the need for a separate Sub operator. By the same reasoning, division a/b could also be defined as a*pow(b,-1).
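To see why the division shortcut would give the same derivative, the quotient rule can be compared numerically with the product rule applied to a * b**-1. This is a minimal sketch; the function names are illustrative and do not appear in the article.

```python
def quotient_rule(a, b, da, db):
    """Derivative of a / b, written the same way as in the Divide node."""
    return (da * b - db * a) / (b ** 2)


def product_power_rule(a, b, da, db):
    """Derivative of a * b**-1 via the product rule and the constant-power rule."""
    d_inverse = -1 * (b ** -2) * db  # d(b**-1) = -1 * b**-2 * db
    return da * (b ** -1) + a * d_inverse


# The term 6 / x at x = 18: a = 6, b = x, da = 0, db = 1
print(quotient_rule(6, 18, 0, 1))
print(product_power_rule(6, 18, 0, 1))

# An arbitrary second check with non-zero da
print(quotient_rule(2, 4, 3, 5))
print(product_power_rule(2, 4, 3, 5))
```

The two formulas agree up to floating-point rounding, so a Divide class is a convenience rather than a necessity.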
Output:
1130.3333333333333
117.98148148148148
(((((3*(x^2))+((5*x)*y))+(6/x))+(-1*(8*(y^2))))+10)
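One way to sanity-check an automatic derivative is to compare it with a hand-derived formula and with a numerical central difference. The following is a minimal sketch for the example function at x = 18, y = 2; the helper names and the step size `h` are assumptions chosen for this illustration.

```python
def f(x, y):
    """The example expression, written with plain numbers."""
    return 3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10


def df_dx(x, y):
    """Hand-derived partial derivative with respect to x: 6x + 5y - 6/x^2."""
    return 6 * x + 5 * y - 6 / (x ** 2)


def central_difference(func, x, y, h=1e-5):
    """Numerical estimate of the partial derivative with respect to x."""
    return (func(x + h, y) - func(x - h, y)) / (2 * h)


print(f(18, 2))                      # value of the expression
print(df_dx(18, 2))                  # analytic derivative
print(central_difference(f, 18, 2))  # numerical estimate, close to the analytic one
```

The numerical estimate only approximates the derivative, so it should be compared with a tolerance rather than for exact equality.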




Reposted from: 人工智能前沿讲习

Article URL: http://www.python88.com/topic/127183
 