Trainer
tflearn.helpers.trainer.Trainer (train_ops, graph=None, clip_gradients=5.0, tensorboard_dir='/tmp/tflearn_logs/', tensorboard_verbose=0, checkpoint_path=None, best_checkpoint_path=None, max_checkpoints=None, keep_checkpoint_every_n_hours=10000.0, random_seed=None, session=None, best_val_accuracy=0.0)
Generic class to handle any TensorFlow graph training. It requires
the use of TrainOp
to specify all optimization parameters.
Arguments
- train_ops: list of TrainOp. A list of network training operations for performing optimization.
- graph: tf.Graph. The TensorFlow graph to use. Default: default tf graph.
- clip_gradients: float. Clip gradients. Default: 5.0.
- tensorboard_dir: str. Tensorboard log directory. Default: "/tmp/tflearn_logs/".
- tensorboard_verbose: int. Verbosity level. It supports:
    0 - Loss, Accuracy. (Best Speed)
    1 - Loss, Accuracy, Gradients.
    2 - Loss, Accuracy, Gradients, Weights.
    3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity. (Best Visualization)
- checkpoint_path: str. Path to store model checkpoints. If None, no model checkpoint will be saved. Default: None.
- best_checkpoint_path: str. Path to store the model when the validation rate reaches its highest point of the current training session and is also above best_val_accuracy. Default: None.
- max_checkpoints: int or None. Maximum number of checkpoints to keep. If None, no limit. Default: None.
- keep_checkpoint_every_n_hours: float. Number of hours between each model checkpoint.
- random_seed: int. Random seed, for test reproducibility. Default: None.
- session: Session. A session for running ops. If None, a new one will be created. Note: when providing a session, variables must already have been initialized, otherwise an error will be raised.
- best_val_accuracy: float. The minimum validation accuracy that must be achieved before a model's weights are saved to best_checkpoint_path. This allows the user to skip early saves and also to set a minimum save point when continuing to train a reloaded model. Default: 0.0.
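For context, a minimal sketch of wiring a single TrainOp into a Trainer; the placeholders, layer shapes, data and checkpoint paths below are illustrative assumptions, not part of this API:

import tensorflow as tf
import tflearn

# Illustrative graph: a linear softmax classifier over 784-d inputs (assumed).
X = tf.placeholder("float", [None, 784])
Y = tf.placeholder("float", [None, 10])
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.random_normal([10]))
net = tf.add(tf.matmul(X, W), b)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=net, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(net, 1), tf.argmax(Y, 1)), tf.float32))

# One TrainOp holds the optimization parameters; the Trainer manages the
# session, Tensorboard logging and checkpointing around it.
trainop = tflearn.TrainOp(loss=loss, optimizer=optimizer, metric=accuracy, batch_size=128)
trainer = tflearn.Trainer(train_ops=trainop,
                          tensorboard_dir='/tmp/tflearn_logs/',
                          tensorboard_verbose=0,
                          checkpoint_path='/tmp/tflearn_ckpt/model',  # assumed path
                          max_checkpoints=3)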
Methods
fit (feed_dicts, n_epoch=10, val_feed_dicts=None, show_metric=False, snapshot_step=None, snapshot_epoch=True, shuffle_all=None, dprep_dict=None, daug_dict=None, excl_trainops=None, run_id=None, callbacks=[])
Train the network with the given feed dicts.
Examples
# 1 Optimizer
trainer.fit(feed_dicts={input1: X, output1: Y}, val_feed_dicts={input1: X, output1: Y})
trainer.fit(feed_dicts={input1: X1, input2: X2, output1: Y}, val_feed_dicts=0.1) # 10% of data used for validation
# 2 Optimizers
trainer.fit(feed_dicts=[{in1: X1, out1: Y}, {in2: X2, out2: Y2}], val_feed_dicts=[{in1: X1, out1: Y}, {in2: X2, out2: Y2}])
Arguments
- feed_dicts: dict or list of dict. The dictionary to feed data to the network. It follows TensorFlow feed dict specifications: '{placeholder: data}'. In case of multiple optimizers, a list of dicts is expected; they will feed the respective optimizers.
- n_epoch: int. Number of epochs to run.
- val_feed_dicts: dict, list of dict, float or list of float. The data used for validation. Feed dicts follow the same specification as feed_dicts above. It is also possible to provide a float to split the training data for validation (note that this will shuffle the data).
- show_metric: bool. If True, accuracy will be calculated and displayed at every step. Might slow down training.
- snapshot_step: int. If not None, the network will be snapshot every given number of steps (calculate validation loss/accuracy and save the model, if a checkpoint_path is specified in Trainer).
- snapshot_epoch: bool. If True, snapshot the network at the end of every epoch.
- shuffle_all: bool. If True, shuffle all data batches (overrides the TrainOp shuffle parameter behavior).
- dprep_dict: dict with Placeholder as key and DataPreprocessing as value. Apply realtime data preprocessing to the given placeholders (applied at training and testing time).
- daug_dict: dict with Placeholder as key and DataAugmentation as value. Apply realtime data augmentation to the given placeholders (only applied at training time).
- excl_trainops: list of TrainOp. A list of train ops to exclude from the training process.
- run_id: str. A name for the current run. Used for Tensorboard display. If no name is provided, a random one will be generated.
- callbacks: Callback or list. Custom callbacks to use in the training life cycle.
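Building on the examples above, a hedged sketch of a fuller fit call; trainX/trainY/testX/testY and the X/Y placeholders are assumed names, while the keyword arguments are the ones documented here:

# Explicit validation set, metric display and a named run for Tensorboard.
trainer.fit(feed_dicts={X: trainX, Y: trainY},
            val_feed_dicts={X: testX, Y: testY},
            n_epoch=10, show_metric=True, snapshot_epoch=True,
            run_id='trainer_demo')

# Or hold out 10% of the (shuffled) training data for validation instead.
trainer.fit(feed_dicts={X: trainX, Y: trainY}, val_feed_dicts=0.1,
            n_epoch=10, run_id='trainer_demo_split')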
fit_batch (feed_dicts, dprep_dict=None, daug_dict=None)
Train network with a single batch.
Arguments
- feed_dicts: dict or list of dict. The dictionary to feed data to the network. It follows TensorFlow feed dict specifications: '{placeholder: data}'. In case of multiple optimizers, a list of dicts is expected; they will feed the respective optimizers.
- dprep_dict: dict with Placeholder as key and DataPreprocessing as value. Apply realtime data preprocessing to the given placeholders (applied at training and testing time).
- daug_dict: dict with Placeholder as key and DataAugmentation as value. Apply realtime data augmentation to the given placeholders (only applied at training time).
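For example, a hedged sketch of driving training manually one batch at a time; next_batch is a hypothetical helper and X/Y assumed placeholders, not part of this API:

for step in range(1000):
    batch_x, batch_y = next_batch(128)  # hypothetical batch generator
    # Runs a single optimization step on this batch only (no epochs, no snapshots).
    trainer.fit_batch(feed_dicts={X: batch_x, Y: batch_y})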
restore (model_file, trainable_variable_only=False, variable_name_map=None, scope_for_restore=None, create_new_session=True, verbose=False)
Restore a Tensorflow model
Arguments
- model_file: path of the TensorFlow model to restore.
- trainable_variable_only: If True, only restore trainable variables.
- variable_name_map: a (pattern, repl) tuple providing a regular expression pattern and replacement, applied to variable names before restoration from the model file; or a function map_func used to perform the mapping, called as: name_in_file = map_func(existing_var_op_name). The function may return None to indicate that a variable is not to be restored.
- scope_for_restore: string specifying the scope to limit to when restoring variables. Also removes the scope name prefix from the variable name used when restoring.
- create_new_session: Set to False if the current session is to be kept. Set to True (the default) to create a new session and re-initialize all variables.
- verbose: Set to True to see a printout of which variables are being restored when using scope_for_restore or variable_name_map.
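A hedged sketch of typical restore calls; the checkpoint paths and the regex pattern are illustrative assumptions:

# Restore all weights from an earlier save (default: new session, variables re-initialized first).
trainer.restore('/tmp/tflearn_ckpt/model-1000')

# Restore only trainable variables, mapping in-graph names like 'new_scope/W'
# to the names stored in the checkpoint ('W') before lookup.
trainer.restore('/tmp/tflearn_ckpt/model-1000',
                trainable_variable_only=True,
                variable_name_map=('^new_scope/', ''),
                verbose=True)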
save (model_file, global_step=None)
Save a Tensorflow model
Arguments
- model_file: str. Saving path of the TensorFlow model.
- global_step: int. The training step to append to the model file name (optional).
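For instance (paths are illustrative; with global_step the step number is appended to the file name as noted above):

trainer.save('/tmp/tflearn_ckpt/model')
trainer.save('/tmp/tflearn_ckpt/model', global_step=500)  # step appended to the file name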
TrainOp
tflearn.helpers.trainer.TrainOp (loss, optimizer, metric=None, batch_size=64, ema=0.0, trainable_vars=None, shuffle=True, step_tensor=None, validation_monitors=None, validation_batch_size=None, name=None, graph=None)
TrainOp represents a set of operations used for optimizing a network.
A TrainOp is meant to hold all training parameters of an optimizer.
The Trainer class will then instantiate them, taking into account all
optimizers of the network (setting names, scopes, optimization ops, etc.).
Arguments
- loss: Tensor. Loss operation to evaluate the network cost. The optimizer will use this cost function to train the network.
- optimizer: Optimizer. TensorFlow Optimizer. The optimizer to use to train the network.
- metric: Tensor. The metric tensor to be used for evaluation.
- batch_size: int. Batch size for the data fed to this optimizer. Default: 64.
- ema: float. Exponential moving averages.
- trainable_vars: list of tf.Variable. List of trainable variables to use for training. Default: all trainable variables.
- shuffle: bool. Shuffle data.
- step_tensor: tf.Tensor. A variable holding the training step. If not provided, it will be created. Defining the step tensor early can be useful for network creation, such as for learning rate decay.
- validation_monitors: list of Tensor objects. List of variables to compute during validation, which are also used to produce summaries for output to TensorBoard. For example, this can be used to periodically record a confusion matrix or AUC metric during training. Each variable should have rank 1, i.e. shape [None].
- validation_batch_size: int or None. If int, specifies the batch size to be used for the validation data feed; otherwise defaults to the same value as batch_size.
- name: str. A name for this class (optional).
- graph: tf.Graph. TensorFlow graph to use for training. Default: default tf graph.
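As a hedged illustration of the step_tensor argument, a sketch where the step variable is created early so a decayed learning rate can depend on it; loss is assumed to be defined as in the earlier sketch, and the names are illustrative:

# Create the step tensor up-front so the learning rate can decay with it.
step = tf.Variable(0, trainable=False, name='global_step')
lr = tf.train.exponential_decay(0.1, step, decay_steps=1000, decay_rate=0.96)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=lr)

trainop = tflearn.TrainOp(loss=loss, optimizer=optimizer,
                          batch_size=64,
                          step_tensor=step,          # reused by the Trainer instead of creating one
                          validation_batch_size=256,
                          name='decayed_sgd')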
Methods
initialize_fit (feed_dict, val_feed_dict, dprep_dict, daug_dict, show_metric, summ_writer, coord)
Initialize data for feeding the training process. It is meant to
be used by Trainer
before starting to fit data.
Arguments
- feed_dict: dict. The data dictionary to feed.
- val_feed_dict: dict or float. The validation data dictionary to feed, or the validation split.
- dprep_dict: dict. Data preprocessing dict (with placeholder as key and corresponding DataPreprocessing object as value).
- daug_dict: dict. Data augmentation dict (with placeholder as key and corresponding DataAugmentation object as value).
- show_metric: bool. If True, display accuracy at every step.
- summ_writer: SummaryWriter. The summary writer to use for Tensorboard logging.
initialize_training_ops (i, session, tensorboard_verbose, clip_gradients)
Initialize all ops used for training. Because a network can have
multiple optimizers, an id 'i' is allocated to differentiate them.
This is meant to be used by Trainer
when initializing all train ops.
Arguments
- i: int. This optimizer's training process ID.
- session: tf.Session. The session used to train the network.
- tensorboard_verbose: int. Logging verbosity. Supports:
    0 - Loss, Accuracy.
    1 - Loss, Accuracy, Gradients.
    2 - Loss, Accuracy, Gradients, Weights.
    3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity.
- clip_gradients: float. Option for clipping gradients.