我正在尝试实现一个通用的和灵活的
__eq__
方法,该方法用于尽可能多的对象类型,包括iterables和numpy数组。
以下是我目前为止的情况:
class Environment:
def __init__(self, state):
self.state = state
def __eq__(self, other):
"""Compare two environments based on their states.
"""
if isinstance(other, self.__class__):
try:
return all(self.state == other.state)
except TypeError:
return self.state == other.state
return False
这对大多数对象类型(包括一维数组)都适用:
s = np.array(range(6))
e1 = Environment(s)
e2 = Environment(s)
e1 == e2 # True
s = 'abcdef'
e1 = Environment(s)
e2 = Environment(s)
e1 == e2 # True
问题是,当
self.state
是一个多维的numpy数组。
s = np.array(range(6)).reshape((2, 3))
e1 = Environment(s)
e2 = Environment(s)
e1 == e2
生产:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
很明显,我可以查一下
isinstance(other, np.ndarray)
然后做
(return self.state == other.state).all()
但我只是想有一种更通用的方法可以用一条语句处理所有的iterables、集合和任何类型的数组。
我也有点困惑为什么
all()
不会像这样迭代数组的所有元素
array.all()
. 有办法触发吗
np.nditer
可能会这样做?