@@ -76,7 +76,7 @@ def __init__(self, env, scope="estimator", summaries_dir=None):
7676 self .num_actions = env .action_space .n
7777 self .epsilon = initial_epsilon
7878 self .epsilon_step = (initial_epsilon - final_epsilon ) / exploration_steps
79-
79+
8080 # Writes Tensorboard summaries to disk
8181 self .summary_writer = None
8282 with tf .variable_scope (scope ):
@@ -110,7 +110,7 @@ def build_model(self):
110110
111111 a_one_hot = tf .one_hot (self .actions , self .num_actions , 1.0 , 0.0 )
112112 q_value = tf .reduce_sum (tf .multiply (self .predictions , a_one_hot ), reduction_indices = 1 )
113-
113+
114114 # Calculate the loss
115115 self .losses = tf .squared_difference (self .y , q_value )
116116 self .loss = tf .reduce_mean (self .losses )
@@ -176,24 +176,24 @@ def copy_model_parameters(estimator1, estimator2):
176176 return update_ops
177177
def create_memory(env):
    """Populate the replay memory with initial random-policy experience.

    Resets ``env``, then takes ``replay_memory_init_size`` steps with
    uniformly random actions, storing one ``Transition`` per step.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` and
            ``action_space.n`` (presumably a Gym environment -- confirm).

    Returns:
        A ``collections.deque`` of ``Transition`` records, oldest first.
        A deque is returned rather than a list because the training loop
        evicts old entries with ``replay_memory.popleft()``, which a plain
        list does not provide.
    """
    # Local import keeps this block self-contained without touching the
    # module-level import section.
    from collections import deque

    replay_memory = deque()

    frame = env.reset()
    state = get_initial_state(frame)

    for _ in range(replay_memory_init_size):
        action = np.random.choice(np.arange(env.action_space.n))
        frame, reward, done, _info = env.step(action)

        # Slide the frame window: drop the oldest entry along axis 0 and
        # append the newly preprocessed frame.
        next_state = np.append(state[1:, :, :], preprocess(frame), axis=0)

        # NOTE(review): the training loop clips rewards to [-1, 1] before
        # storing them; rewards stored here are unclipped -- confirm whether
        # that difference is intended.
        replay_memory.append(Transition(state, action, reward, next_state, done))

        if done:
            # Episode ended: start a new one from a fresh initial state.
            frame = env.reset()
            state = get_initial_state(frame)
        else:
            state = next_state

    return replay_memory
198198
199199
@@ -222,29 +222,30 @@ def setup_summary():
222222
223223 # Create a global step variable
224224 global_step = tf .Variable (0 , name = 'global_step' , trainable = False )
225-
225+
226226 # Create estimators
227227 q_estimator = Estimator (env , scope = "q" , summaries_dir = tensorboard_path )
228228 target_estimator = Estimator (env , scope = "target_q" )
229-
229+
230230 copy_model = copy_model_parameters (q_estimator , target_estimator )
231-
231+
232232 summary_placeholders , update_ops , summary_op = setup_summary ()
233233
234234 # The replay memory
235235 replay_memory = create_memory (env )
236-
236+
237237 with tf .Session () as sess :
238238 sess .run (tf .global_variables_initializer ())
239239
240-
240+ q_estimator .summary_writer .add_graph (sess .graph )
241+
241242 saver = tf .train .Saver ()
242243 # Load a previous checkpoint if we find one
243244 latest_checkpoint = tf .train .latest_checkpoint (network_path )
244245 if latest_checkpoint :
245246 print ("Loading model checkpoint %s...\n " % latest_checkpoint )
246247 saver .restore (sess , latest_checkpoint )
247-
248+
248249 total_t = sess .run (tf .train .get_global_step ())
249250
250251 for episode in tqdm (range (n_episodes )):
@@ -254,7 +255,7 @@ def setup_summary():
254255
255256 frame = env .reset ()
256257 state = get_initial_state (frame )
257-
258+
258259 total_reward = 0
259260 total_loss = 0
260261 total_q_max = 0
@@ -266,30 +267,30 @@ def setup_summary():
266267
267268 action = q_estimator .get_action (sess , state )
268269 frame , reward , terminal , _ = env .step (action )
269-
270+
270271 processed_frame = preprocess (frame )
271272 next_state = np .append (state [1 :, :, :], processed_frame , axis = 0 )
272-
273+
273274 reward = np .clip (reward , - 1 , 1 )
274275 replay_memory .append (Transition (state , action , reward , next_state , terminal ))
275276 if len (replay_memory ) > replay_memory_size :
276277 replay_memory .popleft ()
277-
278+
278279 samples = random .sample (replay_memory , batch_size )
279280 states_batch , action_batch , reward_batch , next_states_batch , done_batch = map (np .array , zip (* samples ))
280-
281+
281282 # Calculate q values and targets (Double DQN)
282283 adapted_state = adapt_batch_state (next_states_batch )
283-
284+
284285 q_values_next = q_estimator .predict (sess , adapted_state )
285286 best_actions = np .argmax (q_values_next , axis = 1 )
286287 q_values_next_target = target_estimator .predict (sess , adapted_state )
287288 targets_batch = reward_batch + np .invert (done_batch ).astype (np .float32 ) * gamma * q_values_next_target [np .arange (batch_size ), best_actions ]
288-
289+
289290 # Perform gradient descent update
290291 states_batch = adapt_batch_state (states_batch )
291292 loss = q_estimator .update (sess , states_batch , action_batch , targets_batch )
292-
293+
293294 total_q_max += np .max (q_values_next )
294295 total_loss += loss
295296 total_t += 1
0 commit comments