View source on GitHub
Exposes a numpy API for saved_model policies in Eager mode.
Inherits From: PyTFEagerPolicyBase, PyPolicy
tf_agents.policies.SavedModelPyTFEagerPolicy(
model_path: Text,
time_step_spec: Optional[tf_agents.trajectories.TimeStep] = None,
action_spec: Optional[tf_agents.typing.types.DistributionSpecV2] = None,
policy_state_spec: tf_agents.typing.types.NestedTensorSpec = (),
info_spec: tf_agents.typing.types.NestedTensorSpec = (),
load_specs_from_pbtxt: bool = False,
use_tf_function: bool = False,
batch_time_steps=True
)
| Args | |
|---|---|
| model_path | Path to a saved_model generated by the policy_saver. |
| time_step_spec | Optional nested structure of ArraySpecs describing the policy's time_step_spec. This is not used by the SavedModelPyTFEagerPolicy, but may be accessed by other objects as it is part of the public policy API. |
| action_spec | Optional nested structure of ArraySpecs describing the policy's action_spec. This is not used by the SavedModelPyTFEagerPolicy, but may be accessed by other objects as it is part of the public policy API. |
| policy_state_spec | Optional nested structure of ArraySpecs describing the policy's policy_state_spec. This is not used by the SavedModelPyTFEagerPolicy, but may be accessed by other objects as it is part of the public policy API. |
| info_spec | Optional nested structure of ArraySpecs describing the policy's info_spec. This is not used by the SavedModelPyTFEagerPolicy, but may be accessed by other objects as it is part of the public policy API. |
| load_specs_from_pbtxt | If True, the specs are loaded from the proto file generated by the policy_saver. |
| use_tf_function | See PyTFEagerPolicyBase. |
| batch_time_steps | See PyTFEagerPolicyBase. |
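As a minimal sketch of how the constructor arguments above fit together, the helper below builds a policy from a saved_model directory. The helper name `load_saved_policy` and its `policy_cls` parameter are hypothetical and exist only so the function can be exercised without TF-Agents installed. The sketch assumes the saved_model directory was written by the policy_saver together with its spec file, so `load_specs_from_pbtxt=True` is a reasonable default.

```python
def load_saved_policy(model_path, policy_cls=None, **kwargs):
    """Builds a policy object from a saved_model directory.

    `policy_cls` is a hypothetical hook for substituting another class
    (for example in tests); by default SavedModelPyTFEagerPolicy is used.
    """
    if policy_cls is None:
        # Imported lazily so this helper can be defined without TF-Agents.
        from tf_agents.policies import SavedModelPyTFEagerPolicy
        policy_cls = SavedModelPyTFEagerPolicy
    # Assumption: specs were saved next to the model by the policy_saver,
    # so they can be loaded instead of being passed in explicitly.
    kwargs.setdefault("load_specs_from_pbtxt", True)
    return policy_cls(model_path, **kwargs)
```

A caller that already has the environment's specs could instead pass `time_step_spec` and `action_spec` explicitly and set `load_specs_from_pbtxt=False`.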
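| Attributes | |
|---|---|
| action_spec | Describes the ArraySpecs of the np.Array returned by action(). The action can be a single np.Array, or a nested dict, list or tuple of np.Array. |
| collect_data_spec | |
| info_spec | Describes the Arrays emitted as info by action(). |
| observation_and_action_constraint_splitter | |
| policy_state_spec | Describes the Arrays expected by functions that take policy_state as input. |
| policy_step_spec | Describes the output of action(). |
| time_step_spec | Describes the TimeStep np.Arrays expected by action(time_step). |
| trajectory_spec | |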
Methods
action
action(
time_step: tf_agents.trajectories.TimeStep,
policy_state: tf_agents.typing.types.NestedArray = (),
seed: Optional[types.Seed] = None
) -> tf_agents.trajectories.PolicyStep
Generates next action given the time_step and policy_state.
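| Args | |
|---|---|
| time_step | A TimeStep tuple corresponding to time_step_spec(). |
| policy_state | The policy state to use for this step, such as the state returned by the previous call to action or by get_initial_state(). Defaults to (). |
| seed | Optional seed; defaults to None. |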
| Returns | |
|---|---|
| A PolicyStep named tuple containing: action, a nest of action Arrays matching the action_spec(); state, a nest of policy states to be fed into the next call to action; info, optional side information such as action log probabilities. |
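The way the state field of one PolicyStep feeds the next call to action can be sketched as a simple rollout loop. This is an illustrative sketch, not part of the library: `run_episode` and `max_steps` are hypothetical names, and `env` is assumed to expose `reset()` and `step(action)` returning TimeStep-like objects with `is_last()` and `reward`, in the style of a TF-Agents Python environment.

```python
def run_episode(policy, env, batch_size=None, max_steps=1000):
    """Rolls out one episode and returns (total_reward, steps_taken)."""
    time_step = env.reset()
    # Start from the policy's initial state; for stateless policies this
    # is typically an empty nest.
    policy_state = policy.get_initial_state(batch_size)
    total_reward = 0.0
    steps = 0
    while not time_step.is_last() and steps < max_steps:
        policy_step = policy.action(time_step, policy_state)
        # Carry the returned state into the next call to action().
        policy_state = policy_step.state
        time_step = env.step(policy_step.action)
        # Assumption: reward is a scalar (or a size-1 array).
        total_reward += float(time_step.reward)
        steps += 1
    return total_reward, steps
```

The `max_steps` cap is only a guard against environments that never emit a final time step.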
get_initial_state
get_initial_state(
batch_size: Optional[int] = None
) -> tf_agents.typing.types.NestedArray
Returns an initial state usable by the policy.
| Args | |
|---|---|
| batch_size | Optional int batch size; defaults to None. |
| Returns | |
|---|---|
| An initial policy state. |
get_metadata
get_metadata()
Returns the metadata of the saved model.
get_train_step
get_train_step() -> tf_agents.typing.types.Int
Returns the training global step of the saved model.
get_train_step_from_last_restored_checkpoint_path
get_train_step_from_last_restored_checkpoint_path() -> Optional[int]
Returns the training step of the restored checkpoint.
update_from_checkpoint
update_from_checkpoint(
checkpoint_path: Text
)
Allows users to update saved_model variables directly from a checkpoint.
checkpoint_path is a path that was passed to either PolicySaver.save()
or PolicySaver.save_checkpoint(). The policy looks for a set of checkpoint
files with the file prefix `
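| Args | |
|---|---|
| checkpoint_path | Path to the checkpoint to load, as passed to PolicySaver.save() or PolicySaver.save_checkpoint(). |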
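A common use of update_from_checkpoint is refreshing an already loaded policy with newer variables while it keeps serving actions. The sketch below is a hypothetical helper (`refresh_policy` is not part of the library) that records the training step before and after the update, assuming `policy` exposes the `get_train_step()` and `update_from_checkpoint()` methods documented here.

```python
def refresh_policy(policy, checkpoint_path):
    """Updates policy variables from a checkpoint.

    Returns a (step_before, step_after) tuple so callers can confirm
    that the update actually moved the policy to a different step.
    """
    step_before = int(policy.get_train_step())
    policy.update_from_checkpoint(checkpoint_path)
    step_after = int(policy.get_train_step())
    return step_before, step_after
```

Comparing the two steps is only a sanity check; whether the train step is stored with the checkpoint depends on how the policy was saved.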
variables
variables()