tf_agents.utils.common.Checkpointer is a utility to save/load the training state, policy state, and replay_buffer state to/from local storage.
tf_agents.policies.policy_saver.PolicySaver is a tool to save/load only the policy, and is lighter than Checkpointer. You can also use PolicySaver to deploy the model without any knowledge of the code that created the policy.
In this tutorial, we will use DQN to train a model, then use Checkpointer and PolicySaver to show how we can store and load the states and model in an interactive way. Note that we will use TF2.0's new saved_model tooling and format for PolicySaver.
Setup
If you haven't installed the following dependencies, run:
```python
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()
```
DQN agent
We are going to set up a DQN agent, just like in the previous colab. The details are hidden by default because they are not a core part of this colab, but you can click 'SHOW CODE' to see them.
```python
from tf_agents.drivers import dynamic_step_driver
from tf_agents.replay_buffers import tf_uniform_replay_buffer

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

collect_driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=collect_steps_per_iteration)

# Initial data collection
collect_driver.run()

# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=2).prefetch(3)

iterator = iter(dataset)
```
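To make the `[BxTx...]` shape concrete, here is a minimal pure-Python sketch (not the TF-Agents implementation) of a bounded buffer that samples `B` windows of `T` consecutive steps uniformly at random. `ToyReplayBuffer` is a hypothetical name; with `num_steps=2`, each sample is a pair of adjacent steps, which is what a one-step DQN update consumes. Unlike the real buffer, this sketch ignores episode boundaries and stores plain integers instead of trajectories.

```python
import random
from collections import deque


class ToyReplayBuffer:
  """Hypothetical stand-in for a uniform replay buffer (illustration only)."""

  def __init__(self, max_length):
    # Once max_length items are stored, the oldest item is evicted.
    self._items = deque(maxlen=max_length)

  def add(self, item):
    self._items.append(item)

  def sample(self, sample_batch_size, num_steps, rng=random):
    """Returns sample_batch_size windows, each num_steps consecutive items."""
    if len(self._items) < num_steps:
      raise ValueError('Not enough items to sample a window.')
    # Valid window start positions are 0 .. len(items) - num_steps.
    max_start = len(self._items) - num_steps
    batch = []
    for _ in range(sample_batch_size):
      start = rng.randint(0, max_start)
      batch.append([self._items[start + t] for t in range(num_steps)])
    return batch


buffer = ToyReplayBuffer(max_length=5)
for step in range(8):  # Steps 0..2 are evicted; steps 3..7 remain.
  buffer.add(step)

batch = buffer.sample(sample_batch_size=4, num_steps=2, rng=random.Random(0))
print(len(batch), len(batch[0]))  # B windows, T steps per window
```

Sampling fixed-length windows rather than single steps is what lets the learner see both the current and the next observation of a transition in one sample.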
Train the agent
```python
from tf_agents.utils import common

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

def train_one_iteration():

  # Collect a few steps using collect_policy and save to the replay buffer.
  collect_driver.run()

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience)

  iteration = agent.train_step_counter.numpy()
  print('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))
```
Video Generation
```python
import base64
import io

import imageio
import IPython


def embed_gif(gif_buffer):
  """Embeds a gif file in the notebook."""
  tag = '<img src="data:image/gif;base64,{0}"/>'.format(
      base64.b64encode(gif_buffer).decode())
  return IPython.display.HTML(tag)


def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):
  num_episodes = 3
  frames = []
  for _ in range(num_episodes):
    time_step = eval_tf_env.reset()
    frames.append(eval_py_env.render())
    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = eval_tf_env.step(action_step.action)
      frames.append(eval_py_env.render())
  gif_file = io.BytesIO()
  imageio.mimsave(gif_file, frames, format='gif', fps=60)
  IPython.display.display(embed_gif(gif_file.getvalue()))
```
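`embed_gif` works by inlining the GIF bytes into an HTML `<img>` tag as a base64 data URI, so the notebook needs no separate image file. That encoding step can be checked on its own with only the standard library. `gif_data_uri_tag` is a hypothetical helper that mirrors the tag-building line above, and the input bytes are only a placeholder (the GIF header signature), not a complete image.

```python
import base64


def gif_data_uri_tag(gif_buffer: bytes) -> str:
  """Builds the same <img> tag that embed_gif passes to IPython.display.HTML."""
  encoded = base64.b64encode(gif_buffer).decode()  # bytes -> ASCII str
  return '<img src="data:image/gif;base64,{0}"/>'.format(encoded)


# Placeholder bytes: just the 6-byte GIF signature, not a renderable image.
tag = gif_data_uri_tag(b'GIF89a')
print(tag)
```

Because the whole animation is embedded in the tag, large GIFs make the notebook output correspondingly large; that is the trade-off for not needing a file server.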
Generate a video
Check the performance of the policy by generating a video.
```python
print('Training one iteration....')
train_one_iteration()
```
Training one iteration....
iteration: 1 loss: 0.936349093914032
Save to checkpoint
```python
train_checkpointer.save(global_step)
```
Restore checkpoint
For this to work, the whole set of objects must be recreated in the same way as when the checkpoint was created.
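Why the objects must match: TensorFlow's object-based checkpoints, which Checkpointer builds on, record each saved value under its position in the tree of tracked objects (here, the training, policy, and replay buffer state), and restoring matches values back to objects in the current program by that same structure. The pure-Python analogy below is not how TensorFlow serializes checkpoints; `Counter`, `save_state`, and `restore_state` are hypothetical names used only to illustrate the matching rule.

```python
import json
import os
import tempfile


class Counter:
  """Hypothetical stateful object, standing in for something like a step counter."""

  def __init__(self):
    self.value = 0


def save_state(path, tracked):
  """Writes each tracked object's value under its name."""
  with open(path, 'w') as f:
    json.dump({name: obj.value for name, obj in tracked.items()}, f)


def restore_state(path, tracked):
  """Assigns saved values into already-created objects with matching names."""
  with open(path) as f:
    saved = json.load(f)
  if set(saved) != set(tracked):
    raise ValueError('Object structure differs from the checkpoint: '
                     'saved {} vs current {}'.format(sorted(saved), sorted(tracked)))
  for name, obj in tracked.items():
    obj.value = saved[name]


ckpt_path = os.path.join(tempfile.mkdtemp(), 'ckpt.json')

# First run: mutate the state, then save it.
step = Counter()
step.value = 42
save_state(ckpt_path, {'global_step': step})

# Fresh run: recreate the same object structure, then restore into it.
new_step = Counter()
restore_state(ckpt_path, {'global_step': new_step})
print(new_step.value)

# Restoring into a differently structured object tree is rejected.
try:
  restore_state(ckpt_path, {'train_step': Counter()})
  mismatch_detected = False
except ValueError:
  mismatch_detected = True
print(mismatch_detected)
```

This sketch raises on any mismatch; TensorFlow itself can be more lenient (for example, restoring only the parts that match), so treat it as an illustration of why the structure matters rather than of TensorFlow's exact error handling.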
The policy can be loaded without having any knowledge of what agent or network was used to create it. This makes deployment of the policy much easier.
The rest of the colab will help you export/import the checkpointer and policy directories so that you can continue training later and deploy the model without having to train again.
Now you can go back to 'Train one iteration' and train a few more times so that you can see the difference later on. Once you start to see slightly better results, continue below.
Create zip file and upload zip file
```python
import io
import shutil
import zipfile

try:
  from google.colab import files
except ImportError:
  files = None  # Not running in Colab: uploads and downloads are skipped.


def create_zip_file(dirname, base_filename):
  return shutil.make_archive(base_filename, 'zip', dirname)


def upload_and_unzip_file_to(dirname):
  if files is None:
    return
  uploaded = files.upload()
  for fn in uploaded.keys():
    print('User uploaded file "{name}" with length {length} bytes'.format(
        name=fn, length=len(uploaded[fn])))
    shutil.rmtree(dirname)
    zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')
    zip_files.extractall(dirname)
    zip_files.close()
```
Create a zipped file from the checkpoint directory.
```python
if files is not None:
  files.download(checkpoint_zip_filename)  # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
```
After training for a while (10-15 iterations), download the checkpoint zip file, go to "Runtime > Restart and run all" to reset the training, and come back to this cell. Then upload the downloaded zip file and continue training.
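The download/upload round trip relies on `create_zip_file` (a thin wrapper around `shutil.make_archive`) and on extracting the uploaded bytes with `zipfile`. Outside Colab, where `files.upload()` is unavailable, the same flow can be sketched with the standard library alone. The directory name, archive name, and file contents below are placeholders, not what Checkpointer actually writes.

```python
import io
import os
import shutil
import tempfile
import zipfile

root = tempfile.mkdtemp()
checkpoint_dir = os.path.join(root, 'checkpoint')  # placeholder directory
os.makedirs(checkpoint_dir)
with open(os.path.join(checkpoint_dir, 'checkpoint'), 'w') as f:
  f.write('placeholder checkpoint index\n')

# Zip the directory contents, as create_zip_file does; returns the archive path.
zip_path = shutil.make_archive(
    os.path.join(root, 'exported_cp'), 'zip', checkpoint_dir)

# Simulate the upload: read the archive back as raw bytes.
with open(zip_path, 'rb') as f:
  uploaded_bytes = f.read()

# Replace the directory with the archive contents, as upload_and_unzip_file_to does.
shutil.rmtree(checkpoint_dir)
with zipfile.ZipFile(io.BytesIO(uploaded_bytes), 'r') as zip_file:
  zip_file.extractall(checkpoint_dir)

print(sorted(os.listdir(checkpoint_dir)))
```

Writing the archive outside the directory being zipped (here, next to it under `root`) avoids the archive trying to include itself.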
Once you have uploaded the checkpoint directory, go back to 'Train one iteration' to continue training, or go back to 'Generate a video' to check the performance of the loaded policy.
Alternatively, you can save the policy (model) and restore it.
Unlike with Checkpointer, you cannot continue training, but you can still deploy the model. Note that the downloaded file is much smaller than the checkpoint zip file.
```python
if files is not None:
  files.download(policy_zip_filename)  # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
```
Upload the downloaded policy directory (exported_policy.zip) and check how the saved policy performs.
Note that this only works when eager mode is enabled.
```python
from tf_agents.policies import py_tf_eager_policy

eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())

# Note that we're passing eval_py_env, not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)
```
Last updated 2023-12-22 UTC.