A single-replica view of the training procedure.
Inherits From: Task
tfm.nlp.tasks.TranslationTask(
params: tfm.core.config_definitions.TaskConfig,
logging_dir=None,
name=None
)
Tasks provide the artifacts for training and evaluation procedures, including loading and iterating over datasets, initializing the model, and calculating the loss and customized metrics with reduction.
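The division of labor can be sketched without any framework. In the minimal sketch below, `ToyTask` and `run` are hypothetical stand-ins written for illustration. The methods only mirror names documented on this page, return fake values, and record the order in which a simple driver calls them. It assumes a train-then-evaluate flow and does not reproduce the Model Garden trainer or use TensorFlow.

```python
from typing import Dict, List, Optional


class ToyTask:
    """Hypothetical stand-in mirroring the Task method names.

    Each method records its name so the call order can be inspected;
    no real model, dataset, or TensorFlow code is involved.
    """

    def __init__(self) -> None:
        self.calls: List[str] = []

    def build_model(self) -> str:
        self.calls.append("build_model")
        return "model"

    def build_inputs(self, split: str) -> List[int]:
        self.calls.append(f"build_inputs:{split}")
        return [1, 2]  # two fake batches

    def train_step(self, batch: int, model: str) -> Dict[str, float]:
        self.calls.append("train_step")
        return {"loss": float(batch)}

    def validation_step(self, batch: int, model: str) -> Dict[str, int]:
        self.calls.append("validation_step")
        return {"batch": batch}

    def aggregate_logs(self, state: Optional[List[int]],
                       step_outputs: Dict[str, int]) -> List[int]:
        self.calls.append("aggregate_logs")
        state = [] if state is None else state
        state.append(step_outputs["batch"])
        return state

    def reduce_aggregated_logs(self, aggregated_logs: List[int]) -> Dict[str, int]:
        self.calls.append("reduce_aggregated_logs")
        return {"num_eval_batches": len(aggregated_logs)}


def run(task: ToyTask) -> Dict[str, int]:
    """Hypothetical driver: one training pass, then one evaluation pass."""
    model = task.build_model()
    for batch in task.build_inputs("train"):
        task.train_step(batch, model)
    state = None
    for batch in task.build_inputs("validation"):
        state = task.aggregate_logs(state, task.validation_step(batch, model))
    return task.reduce_aggregated_logs(state)


task = ToyTask()
print(run(task))  # {'num_eval_batches': 2}
print(task.calls)
```

The point of the pattern is that the driver stays generic while everything model- and data-specific lives in the task.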
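| Args | |
|---|---|
| `params` | The task configuration, a `tfm.core.config_definitions.TaskConfig`. |
| `logging_dir` | Optional logging directory. Defaults to `None`. |
| `name` | Optional task name. Defaults to `None`. |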
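| Attributes | |
|---|---|
| `logging_dir` | The logging directory passed to the constructor. |
| `task_config` | The task configuration passed as `params`. |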
Methods
aggregate_logs
aggregate_logs(
state=None, step_outputs=None
)
Aggregates over logs returned from a validation step.
build_inputs
build_inputs(
params: tfm.core.config_definitions.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None
)
Returns a dataset.
build_losses
build_losses(
labels, model_outputs, aux_losses=None
) -> tf.Tensor
Standard interface to compute losses.
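| Args | |
|---|---|
| `labels` | The label tensors. |
| `model_outputs` | The model's output tensors. |
| `aux_losses` | Auxiliary loss tensors, i.e. the `losses` of a `keras.Model`. |

| Returns |
|---|
| The total loss tensor. |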
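As an illustration of what a sequence-to-sequence loss of this shape can look like, the sketch below computes a mean token-level cross-entropy that skips padding positions and then adds any auxiliary losses. It works on plain Python lists rather than tensors. `PAD_ID = 0` is an assumption, and the normalization and any label smoothing used by `TranslationTask` itself may differ.

```python
import math
from typing import List, Optional, Sequence

PAD_ID = 0  # assumption: token id 0 marks padding


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def build_losses(labels: Sequence[int],
                 model_outputs: Sequence[Sequence[float]],
                 aux_losses: Optional[Sequence[float]] = None) -> float:
    """Mean cross-entropy over non-padding tokens, plus auxiliary losses (sketch)."""
    total, count = 0.0, 0
    for label, logits in zip(labels, model_outputs):
        if label == PAD_ID:
            continue  # padding positions contribute no loss
        total -= log_softmax(logits)[label]
        count += 1
    loss = total / max(count, 1)  # avoid dividing by zero on all-padding input
    if aux_losses:
        loss += sum(aux_losses)
    return loss


labels = [2, 1, 0]                  # last position is padding
logits = [[0.0, 0.0, 0.0]] * 3      # uniform predictions over 3 classes
print(round(build_losses(labels, logits), 4))                    # 1.0986
print(round(build_losses(labels, logits, aux_losses=[0.5]), 4))  # 1.5986
```

With uniform logits each real token costs `log(3)`, so the padded position neither raises nor dilutes the mean.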
build_metrics
build_metrics(
training: bool = True
)
Gets streaming metrics for training/validation.
build_model
build_model() -> tf.keras.Model
Creates model architecture.
| Returns |
|---|
| A model instance. |
create_optimizer
@classmethod
create_optimizer(
optimizer_config: tfm.optimization.OptimizationConfig,
runtime_config: Optional[tfm.core.base_task.RuntimeConfig] = None,
dp_config: Optional[tfm.core.base_task.DifferentialPrivacyConfig] = None
)
Creates a TF optimizer from configurations.
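| Args | |
|---|---|
| `optimizer_config` | The optimization settings, a `tfm.optimization.OptimizationConfig`. |
| `runtime_config` | Optional runtime settings, a `tfm.core.base_task.RuntimeConfig`. |
| `dp_config` | Optional differential-privacy settings, a `tfm.core.base_task.DifferentialPrivacyConfig`. |

| Returns |
|---|
| A `tf.optimizers.Optimizer` object. |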
inference_step
inference_step(
inputs, model: tf.keras.Model
)
Performs the forward step.
With distribution strategies, this method runs on devices.
| Args | |
|---|---|
| `inputs` | A batch of input tensors. |
| `model` | The `tf.keras.Model` to run. |

| Returns |
|---|
| Model outputs. |
initialize
initialize(
model: tf.keras.Model
)
[Optional] A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint is found for the model. If there is a checkpoint, the checkpoint will be loaded and this function will not be called. You can use this callback function to load a pretrained checkpoint, saved under a directory other than the model_dir.
| Args | |
|---|---|
| `model` | The `tf.keras.Model` to initialize. |
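The `init_fn` contract described above can be mimicked in a few lines. In the sketch below, `load_or_init` is a hypothetical helper, and the presence of a file named `checkpoint` in `model_dir` is an assumed marker for "a checkpoint exists"; real checkpoint restoration is out of scope.

```python
import os
import tempfile
from typing import Callable, Optional


def load_or_init(model_dir: str, init_fn: Optional[Callable[[], None]] = None) -> str:
    """Hypothetical helper mimicking the init_fn contract described above."""
    # Assumption: a file named "checkpoint" marks an existing checkpoint.
    if os.path.exists(os.path.join(model_dir, "checkpoint")):
        return "restored"  # checkpoint found: init_fn is NOT called
    if init_fn is not None:
        init_fn()  # no checkpoint: e.g. load pretrained weights saved elsewhere
        return "initialized"
    return "fresh"


calls = []
with tempfile.TemporaryDirectory() as model_dir:
    first = load_or_init(model_dir, lambda: calls.append("init"))
    # Simulate a checkpoint being written during training.
    open(os.path.join(model_dir, "checkpoint"), "w").close()
    second = load_or_init(model_dir, lambda: calls.append("init"))

print(first, second, calls)  # initialized restored ['init']
```

Because the callback runs only on a cold start, loading pretrained weights there never overwrites progress saved in `model_dir`.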
process_compiled_metrics
process_compiled_metrics(
compiled_metrics, labels, model_outputs
)
Process and update compiled_metrics.
Called when using the compile/fit API.
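| Args | |
|---|---|
| `compiled_metrics` | The compiled metrics to update. |
| `labels` | The label tensors. |
| `model_outputs` | The model's output tensors. |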
process_metrics
process_metrics(
metrics, labels, model_outputs, **kwargs
)
Process and update metrics.
Called when using the custom training loop API.
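| Args | |
|---|---|
| `metrics` | The metrics objects to update. |
| `labels` | The label tensors. |
| `model_outputs` | The model's output tensors. |
| `**kwargs` | Additional keyword arguments. |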
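A streaming metric keeps running totals across batches, so its result reflects every batch seen so far rather than only the latest one. The sketch below assumes that `model_outputs` are already predicted token ids. `StreamingAccuracy` is a hypothetical class whose `update_state`/`result` method names follow the Keras metric convention, and the `build_metrics`/`process_metrics` functions only mirror the names documented here.

```python
from typing import List, Sequence


class StreamingAccuracy:
    """Hypothetical streaming metric that keeps running totals across batches."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.correct = 0
        self.total = 0

    def update_state(self, y_true: Sequence[int], y_pred: Sequence[int]) -> None:
        # Add this batch's counts to the running totals.
        self.correct += sum(int(t == p) for t, p in zip(y_true, y_pred))
        self.total += len(y_true)

    def result(self) -> float:
        # Accuracy over every batch seen so far.
        return self.correct / self.total if self.total else 0.0


def build_metrics(training: bool = True) -> List[StreamingAccuracy]:
    # This sketch uses the same metric for training and validation.
    return [StreamingAccuracy("accuracy")]


def process_metrics(metrics: List[StreamingAccuracy], labels: Sequence[int],
                    model_outputs: Sequence[int], **kwargs) -> None:
    """Updates every metric in place with one batch (sketch)."""
    for metric in metrics:
        metric.update_state(labels, model_outputs)


metrics = build_metrics(training=False)
process_metrics(metrics, [1, 2, 3, 4], [1, 2, 0, 0])  # 2 of 4 correct
process_metrics(metrics, [5, 6], [5, 6])              # 2 of 2 correct
print({m.name: m.result() for m in metrics})          # {'accuracy': 0.6666666666666666}
```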
reduce_aggregated_logs
reduce_aggregated_logs(
aggregated_logs, global_step=None
)
Optionally reduces aggregated logs over validation steps.
This function reduces aggregated logs at the end of validation and can be used to compute the final metrics. It runs on the CPU, in each eval_end() call of the base trainer (see the eval_end() function in official/core/base_trainer.py).
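| Args | |
|---|---|
| `aggregated_logs` | The logs accumulated by `aggregate_logs`. |
| `global_step` | Optional current global step. Defaults to `None`. |

| Returns |
|---|
| A dictionary of reduced results. |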
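Accumulating raw per-step outputs and reducing them once at the end matters for corpus-level scores, which are generally not the mean of per-batch scores. The sketch below is a hypothetical illustration in plain Python: each step output is assumed to be a `(correct_tokens, total_tokens)` pair, and token accuracy stands in for whatever final metrics `TranslationTask` actually computes.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical per-step output: (number of correct tokens, number of tokens).
StepOutput = Tuple[int, int]


def aggregate_logs(state: Optional[List[StepOutput]],
                   step_outputs: StepOutput) -> List[StepOutput]:
    """Accumulates raw counts from one validation step (sketch)."""
    if state is None:
        state = []
    state.append(step_outputs)
    return state


def reduce_aggregated_logs(aggregated_logs: List[StepOutput],
                           global_step: Optional[int] = None) -> Dict[str, float]:
    """Reduces all accumulated counts into one corpus-level score (sketch)."""
    correct = sum(c for c, _ in aggregated_logs)
    total = sum(n for _, n in aggregated_logs)
    return {"token_accuracy": correct / total if total else 0.0}


state = None
for step_outputs in [(9, 10), (1, 10), (0, 80)]:
    state = aggregate_logs(state, step_outputs)

# Corpus level: 10 / 100 = 0.1. Averaging per-step accuracies instead would
# give (0.9 + 0.1 + 0.0) / 3, about 0.333, overweighting the small batches.
print(reduce_aggregated_logs(state))  # {'token_accuracy': 0.1}
```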
train_step
train_step(
inputs,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics=None
)
Does the forward and backward passes.
With distribution strategies, this method runs on devices.
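| Args | |
|---|---|
| `inputs` | A batch of input tensors. |
| `model` | The `tf.keras.Model` to train. |
| `optimizer` | The `tf.keras.optimizers.Optimizer` that applies the update. |
| `metrics` | Optional metrics objects. Defaults to `None`. |

| Returns |
|---|
| A dictionary of logs. |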
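To show what "forward and backward" means within a single step, the sketch below trains a toy one-feature linear model with mean squared error. It is a hypothetical illustration, not `TranslationTask.train_step`: a hand-derived gradient stands in for automatic differentiation (in TensorFlow one would typically record the forward pass with `tf.GradientTape`), plain SGD stands in for the `optimizer` argument, and distribution strategies and metrics are omitted. The returned dictionary uses the `"loss"` key, matching the `'loss'` class variable.

```python
from typing import Dict, List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def train_step(inputs: Batch, model: Dict[str, float],
               learning_rate: float = 0.1) -> Dict[str, float]:
    """One SGD step on y_hat = w * x + b with mean squared error (sketch)."""
    n = len(inputs)
    w, b = model["w"], model["b"]
    # Forward pass: residuals and loss at the current parameters.
    errors = [(w * x + b) - y for x, y in inputs]
    loss = sum(e * e for e in errors) / n
    # Backward pass: hand-derived gradients of the MSE loss.
    grad_w = sum(2 * e * x for e, (x, _) in zip(errors, inputs)) / n
    grad_b = sum(2 * e for e in errors) / n
    # Update: plain SGD stands in for the optimizer argument.
    model["w"] = w - learning_rate * grad_w
    model["b"] = b - learning_rate * grad_b
    return {"loss": loss}


model = {"w": 0.0, "b": 0.0}
batch = [(1.0, 2.0), (2.0, 4.0)]  # samples from y = 2x
logs = [train_step(batch, model)["loss"] for _ in range(50)]
print(logs[0])             # 10.0
print(logs[-1] < logs[0])  # True
```

The loss in each log is measured before that step's update, which is the usual convention for per-step training logs.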
validation_step
validation_step(
inputs, model: tf.keras.Model, metrics=None
)
Validation step.
With distribution strategies, this method runs on devices.
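| Args | |
|---|---|
| `inputs` | A batch of input tensors. |
| `model` | The `tf.keras.Model` to evaluate. |
| `metrics` | Optional metrics objects. Defaults to `None`. |

| Returns |
|---|
| A dictionary of logs. |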
| Class Variables | |
|---|---|
| `loss` | `'loss'` |