Loading joeynmt/DQN_loop.py +7 −4 Original line number Diff line number Diff line Loading @@ -437,14 +437,17 @@ class QManager(object): # taking the most likely action. # use the hyperparameter nu_pretrain to take the true action # or the one take from the one computed from the q_target if self.dev_network_count < self.nu_pretrain: #print ("Using pretraining...") if self.learn_step_counter % 50 == 1: print ("learn step counter: ", self.learn_step_counter) print ("dev_network_count: ", self.dev_network_count ) if self.learn_step_counter < self.nu_pretrain: if self.learn_step_counter == 1: print ("Using pretraining...") b_a_ = torch.LongTensor(b_memory[:, self.state_size+2 + self.state_size]).view(self.sample_size, 1) else: if self.dev_network_count == self.nu_pretrain: if self.learn_step_counter == self.nu_pretrain: print ("Starting using Q target net....") b_a_ = torch.LongTensor(q_next.max(1)[1].view(self.sample_size, 1).long()) #b_a_ = q_next.max(1)[0].view(self.sample_size, 1).long() # shape (batch, 1) q_eval_next = self.eval_net(b_s_).gather(1, b_a_) # shape (batch, 1) Loading test3.out 0 → 100644 +114085 −0 File added.File size exceeds preview limit. View file Loading
joeynmt/DQN_loop.py +7 −4 Original line number Diff line number Diff line Loading @@ -437,14 +437,17 @@ class QManager(object): # taking the most likely action. # use the hyperparameter nu_pretrain to take the true action # or the one take from the one computed from the q_target if self.dev_network_count < self.nu_pretrain: #print ("Using pretraining...") if self.learn_step_counter % 50 == 1: print ("learn step counter: ", self.learn_step_counter) print ("dev_network_count: ", self.dev_network_count ) if self.learn_step_counter < self.nu_pretrain: if self.learn_step_counter == 1: print ("Using pretraining...") b_a_ = torch.LongTensor(b_memory[:, self.state_size+2 + self.state_size]).view(self.sample_size, 1) else: if self.dev_network_count == self.nu_pretrain: if self.learn_step_counter == self.nu_pretrain: print ("Starting using Q target net....") b_a_ = torch.LongTensor(q_next.max(1)[1].view(self.sample_size, 1).long()) #b_a_ = q_next.max(1)[0].view(self.sample_size, 1).long() # shape (batch, 1) q_eval_next = self.eval_net(b_s_).gather(1, b_a_) # shape (batch, 1) Loading