import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class Agent:
    def __init__(self, input_size, hidden_size, action_size, lr=0.001, gamma=0.9,
                 weight_decay=0.0005):
        # DQN_InOne is the Q-network defined earlier in this project.
        self.net = DQN_InOne(input_size, hidden_size, action_size)
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr,
                                    weight_decay=weight_decay)
        self.gamma = gamma  # discount factor
        self.loss_fn = nn.MSELoss()
    def get_action(self, state, action_space, epsilon=0.1):
        # Epsilon-greedy policy: returns (action, explored), where the flag
        # marks whether the action came from exploration.
        if np.random.random() < epsilon:  # exploration
            return np.random.choice(action_space), True
        else:  # exploitation: pick the action with the highest predicted Q-value
            with torch.no_grad():
                q_values = self.net(state)
            return torch.argmax(q_values).item(), False
    def update(self, state, action, reward, next_state):
        # One-step TD update: target = reward + gamma * max_a' Q(next_state, a').
        q_values = self.net(state)
        with torch.no_grad():  # the target must not receive gradients
            q_next = self.net(next_state)
            q_target = reward + self.gamma * torch.max(q_next)
        loss = self.loss_fn(q_values[action], q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
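
# Usage sketch (not from the source): one action selection followed by one
# TD update. The sizes, state tensors, action_space, and reward below are
# hypothetical placeholders; DQN_InOne is assumed to be defined earlier.
agent = Agent(input_size=4, hidden_size=64, action_size=2)
state = torch.rand(4)       # hypothetical 4-feature state vector
next_state = torch.rand(4)
action, explored = agent.get_action(state, action_space=[0, 1], epsilon=0.1)
agent.update(state, action, reward=1.0, next_state=next_state)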