tflearn.helpers.trainer.Trainer (train_ops, graph=None, clip_gradients=5.0, tensorboard_dir='/tmp/tflearn_logs/', tensorboard_verbose=0, checkpoint_path=None, best_checkpoint_path=None, max_checkpoints=None, keep_checkpoint_every_n_hours=10000.0, random_seed=None, session=None, best_val_accuracy=0.0)

Generic class to handle any TensorFlow graph training. It requires the use of TrainOp to specify all optimization parameters.


  • train_ops: list of TrainOp. A list of the network's training operations used to perform optimization.
  • graph: tf.Graph. The TensorFlow graph to use. Default: default tf graph.
  • clip_gradients: float. Threshold used to clip gradients. Default: 5.0.
  • tensorboard_dir: str. Tensorboard log directory. Default: "/tmp/tflearn_logs/".
  • tensorboard_verbose: int. Verbosity level. Supported values:
      0 - Loss, Accuracy. (Best speed)
      1 - Loss, Accuracy, Gradients.
      2 - Loss, Accuracy, Gradients, Weights.
      3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity. (Best visualization)
    Default: 0.
  • checkpoint_path: str. Path to store model checkpoints. If None, no model checkpoint will be saved. Default: None.
  • best_checkpoint_path: str. Path where the model is stored whenever the validation accuracy reaches a new high for the current training session and is also above best_val_accuracy. Default: None.
  • max_checkpoints: int or None. Maximum number of checkpoints to keep. If None, there is no limit. Default: None.
  • keep_checkpoint_every_n_hours: float. Additionally keep one checkpoint every N hours of training, regardless of max_checkpoints. Default: 10000.0.
  • random_seed: int. Random seed, for test reproducibility. Default: None.
  • session: Session. A session for running ops. If None, a new one will be created. Note: When providing a session, variables must have been initialized already, otherwise an error will be raised.
  • best_val_accuracy: float. The minimum validation accuracy that must be reached before the model's weights are saved to best_checkpoint_path. This lets the user skip early saves and set a minimum save point when continuing to train a reloaded model. Default: 0.0.
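The interaction between best_checkpoint_path and best_val_accuracy can be pictured as a running maximum that starts at the user-supplied threshold. The sketch below is a minimal, hypothetical illustration of that rule and not TFLearn's code: `BestCheckpointTracker` is an invented name, no file is actually written, and the file naming (accuracy appended to the path) is an assumption made only for the example.

```python
class BestCheckpointTracker:
    """Hypothetical sketch: decide when a 'best' checkpoint would be written."""

    def __init__(self, best_checkpoint_path=None, best_val_accuracy=0.0):
        self.best_checkpoint_path = best_checkpoint_path
        # The running best starts at the user-supplied minimum threshold.
        self.best_val_accuracy = best_val_accuracy
        self.saved = []  # (path, accuracy) for each simulated save

    def update(self, val_acc):
        """Record a simulated save and return True if val_acc beats the best."""
        if self.best_checkpoint_path is None:
            return False  # feature disabled: nothing is ever saved
        if val_acc > self.best_val_accuracy:
            self.best_val_accuracy = val_acc
            # Assumed naming scheme: the accuracy is appended to the path.
            path = "{}{:.4f}".format(self.best_checkpoint_path, val_acc)
            self.saved.append((path, val_acc))
            return True
        return False


tracker = BestCheckpointTracker("/tmp/best_model_", best_val_accuracy=0.80)
for acc in [0.75, 0.82, 0.81, 0.90]:
    tracker.update(acc)
print(tracker.saved)  # [('/tmp/best_model_0.8200', 0.82), ('/tmp/best_model_0.9000', 0.9)]
```

Starting the running best at best_val_accuracy is what makes the parameter useful when resuming a reloaded model: early epochs that do not beat the previous best produce no save.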


fit (feed_dicts, n_epoch=10, val_feed_dicts=None, show_metric=False, snapshot_step=None, snapshot_epoch=True, shuffle_all=None, dprep_dict=None, daug_dict=None, excl_trainops=None, run_id=None, callbacks=[])

Train the network with the given feed dicts.

# 1 Optimizer
trainer.fit(feed_dicts={input1: X, output1: Y},
            val_feed_dicts={input1: X, output1: Y})
trainer.fit(feed_dicts={input1: X1, input2: X2, output1: Y},
            val_feed_dicts=0.1)  # 10% of data used for validation

# 2 Optimizers
trainer.fit(feed_dicts=[{in1: X1, out1: Y}, {in2: X2, out2: Y2}],
            val_feed_dicts=[{in1: X1, out1: Y}, {in2: X2, out2: Y2}])
  • feed_dicts: dict or list of dict. The dictionary used to feed data to the network. It follows the TensorFlow feed dict specification: '{placeholder: data}'. With multiple optimizers, a list of dicts is expected, one per optimizer, in the same order.
  • n_epoch: int. Number of epochs to run.
  • val_feed_dicts: dict, list of dict, float or list of float. The data used for validation. Feed dicts follow the same specification as feed_dicts above. A float can also be given to split off that fraction of the training data for validation (note that this shuffles the data).
  • show_metric: bool. If True, accuracy is calculated and displayed at every step. This may slow down training.
  • snapshot_step: int. If not None, a snapshot of the network is taken every snapshot_step steps (validation loss/accuracy are computed and, if a checkpoint_path was specified in Trainer, the model is saved).
  • snapshot_epoch: bool. If True, snapshot the network at the end of every epoch.
  • shuffle_all: bool. If True, shuffle all data batches (overrides TrainOp shuffle parameter behavior).
  • dprep_dict: dict with Placeholder as key and DataPreprocessing as value. Apply realtime data preprocessing to the given placeholders (Applied at training and testing time).
  • daug_dict: dict with Placeholder as key and DataAugmentation as value. Apply realtime data augmentation to the given placeholders (Only applied at training time).
  • excl_trainops: list of TrainOp. A list of train ops to exclude from training process.
  • run_id: str. A name for the current run, used for TensorBoard display. If no name is provided, a random one is generated.
  • callbacks: Callback or list. Custom callbacks to use during the training life cycle.
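When val_feed_dicts is a float, part of the training data is held out for validation after shuffling. The sketch below shows one plausible way such a split works on plain Python lists; `split_feed_dict` is a hypothetical helper (string keys stand in for placeholders), not part of the TFLearn API, and TFLearn's own splitting code may differ in details such as rounding.

```python
import random


def split_feed_dict(feed_dict, val_fraction, seed=None):
    """Hypothetical helper: shuffle samples, then hold out val_fraction of them.

    feed_dict maps a key (standing in for a placeholder) to a list of samples;
    all lists must have the same length so that samples stay aligned.
    """
    lengths = {len(v) for v in feed_dict.values()}
    if len(lengths) != 1:
        raise ValueError("all feed_dict entries must have the same length")
    n = lengths.pop()
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # splitting shuffles the data
    n_val = int(n * val_fraction)
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    # The same index permutation is applied to every entry, keeping X/Y pairs.
    train = {k: [v[i] for i in train_idx] for k, v in feed_dict.items()}
    val = {k: [v[i] for i in val_idx] for k, v in feed_dict.items()}
    return train, val


X = list(range(100))
Y = [x % 2 for x in X]
train, val = split_feed_dict({"input1": X, "output1": Y}, 0.1, seed=0)
print(len(train["input1"]), len(val["input1"]))  # 90 10
```

Shuffling one shared index list, rather than each entry separately, is the important design point: inputs and targets must be permuted identically.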

restore (model_file, trainable_variable_only=False, variable_name_map=None, scope_for_restore=None, create_new_session=True, verbose=False)

Restore a Tensorflow model

  • model_file: str. Path of the TensorFlow model to restore.
  • trainable_variable_only: If True, only restore trainable variables.
  • variable_name_map: either a (pattern, repl) tuple giving a regular expression pattern and replacement that is applied to variable names before they are restored from the model file, or a function map_func that performs the mapping, called as name_in_file = map_func(existing_var_op_name). The function may return None to indicate that a variable should not be restored.
  • scope_for_restore: str. The scope to limit restoration to. The scope name prefix is also removed from each variable name when looking it up in the model file.
  • create_new_session: Set to False to keep the current session. Set to True (the default) to create a new session and re-initialize all variables.
  • verbose: Set to True to print which variables are being restored when using scope_for_restore or variable_name_map.
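The two name-mapping options above can be illustrated on plain strings. This is a minimal sketch under stated assumptions: `map_variable_name` and `strip_scope` are hypothetical helpers written for this example, each option is shown on its own, and how TFLearn combines them internally when both are given is not shown here.

```python
import re


def map_variable_name(existing_name, variable_name_map):
    """Return the name to look up in the model file, or None to skip it."""
    if callable(variable_name_map):
        # A user function: name_in_file = map_func(existing_var_op_name)
        return variable_name_map(existing_name)
    pattern, repl = variable_name_map  # (regex pattern, replacement) tuple
    return re.sub(pattern, repl, existing_name)


def strip_scope(existing_name, scope):
    """Mimic scope_for_restore: skip names outside scope, drop the prefix."""
    prefix = scope + "/"
    if not existing_name.startswith(prefix):
        return None  # outside the requested scope: not restored
    return existing_name[len(prefix):]


print(map_variable_name("new_net/FC/W", (r"^new_net/", "old_net/")))  # old_net/FC/W
print(map_variable_name("FC/b", lambda n: None if n.endswith("/b") else n))  # None
print(strip_scope("tower1/Conv2D/W", "tower1"))  # Conv2D/W
```

The tuple form covers the common case of a renamed top-level scope, while the function form allows arbitrary rules, including skipping variables by returning None.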

save (model_file, global_step=None)

Save a Tensorflow model

  • model_file: str. Path where the TensorFlow model is saved.
  • global_step: int. The training step to append to the model file name (optional).
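When global_step is given, the step number ends up in the checkpoint file name. The sketch below assumes the value is forwarded to tf.train.Saver, whose convention is to append "-<global_step>" to the save path; `checkpoint_prefix` is a hypothetical helper that only computes the name and writes nothing.

```python
def checkpoint_prefix(model_file, global_step=None):
    """Return the path prefix a checkpoint would be written under.

    Hypothetical helper following the tf.train.Saver convention of appending
    "-<global_step>" to the save path when a step is given.
    """
    if global_step is None:
        return model_file
    return "%s-%d" % (model_file, global_step)


print(checkpoint_prefix("/tmp/model.tfl"))                    # /tmp/model.tfl
print(checkpoint_prefix("/tmp/model.tfl", global_step=1200))  # /tmp/model.tfl-1200
```

Including the step keeps successive saves from overwriting each other, which is what makes max_checkpoints meaningful.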


tflearn.helpers.trainer.TrainOp (loss, optimizer, metric=None, batch_size=64, ema=0.0, trainable_vars=None, shuffle=True, step_tensor=None, validation_monitors=None, validation_batch_size=None, name=None, graph=None)

TrainOp represents a set of operations used for optimizing a network.

A TrainOp is meant to hold all the training parameters of an optimizer. The Trainer class then instantiates each TrainOp, taking all of the network's optimizers into account (setting names and scopes, creating optimization ops, etc.).


  • loss: Tensor. Loss operation to evaluate network cost. Optimizer will use this cost function to train network.
  • optimizer: Optimizer. Tensorflow Optimizer. The optimizer to use to train network.
  • metric: Tensor. The metric tensor to be used for evaluation.
  • batch_size: int. Batch size of the data fed to this optimizer. Default: 64.
  • ema: float. Exponential moving averages. Default: 0.0.
  • trainable_vars: list of tf.Variable. List of trainable variables to use for training. Default: all trainable variables.
  • shuffle: bool. Whether to shuffle the data. Default: True.
  • step_tensor: tf.Tensor. A variable holding the training step. If not provided, it will be created. Defining the step tensor early can be useful when building the network, for example for learning rate decay.
  • validation_monitors: list of Tensor objects. List of variables to compute during validation, which are also used to produce summaries for output to TensorBoard. For example, this can be used to periodically record a confusion matrix or AUC metric, during training. Each variable should have rank 1, i.e. shape [None].
  • validation_batch_size: int or None. If an int, the batch size used for the validation data feed; otherwise it defaults to batch_size.
  • name: str. A name for this class (optional).
  • graph: tf.Graph. Tensorflow Graph to use for training. Default: default tf graph.
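The batch_size, shuffle and validation_batch_size parameters describe how data is cut into mini-batches for an optimizer. The sketch below is a hypothetical, framework-free illustration of those semantics: `batch_indices` and `resolve_validation_batch_size` are invented names, and keeping a smaller final batch is an assumption of this example rather than a documented TFLearn guarantee.

```python
import random


def batch_indices(n_samples, batch_size, shuffle=True, seed=None):
    """Hypothetical sketch: yield index lists for one epoch of mini-batches."""
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, n_samples, batch_size):
        yield indices[start:start + batch_size]  # the last batch may be smaller


def resolve_validation_batch_size(batch_size, validation_batch_size=None):
    """validation_batch_size falls back to batch_size when it is None."""
    return batch_size if validation_batch_size is None else validation_batch_size


batches = list(batch_indices(150, batch_size=64, shuffle=False))
print([len(b) for b in batches])  # [64, 64, 22]
```

A larger validation batch is often convenient because no gradients are computed during validation, which is why a separate validation_batch_size can be useful.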


initialize_fit (feed_dict, val_feed_dict, dprep_dict, daug_dict, show_metric, summ_writer, coord)

Initialize data for feeding the training process. It is meant to be used by Trainer before starting to fit data.

  • feed_dict: dict. The data dictionary to feed.
  • val_feed_dict: dict or float. The validation data dictionary to feed or validation split.
  • dprep_dict: dict. Data Preprocessing dict (with placeholder as key and corresponding DataPreprocessing object as value).
  • daug_dict: dict. Data Augmentation dict (with placeholder as key and corresponding DataAugmentation object as value).
  • show_metric: bool. If True, display accuracy at every step.
  • summ_writer: SummaryWriter. The summary writer to use for Tensorboard logging.
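The dprep_dict and daug_dict arguments differ in when they apply: preprocessing runs at training and testing time, augmentation only during training. The sketch below illustrates that rule with plain callables standing in for DataPreprocessing and DataAugmentation objects; `apply_realtime_ops`, `center` and `flip` are hypothetical, and applying augmentation before preprocessing is an assumption made for this example.

```python
def apply_realtime_ops(feed, dprep_dict=None, daug_dict=None, training=True):
    """Return a new feed dict with realtime ops applied (hypothetical sketch)."""
    out = dict(feed)
    # Augmentation: training time only.
    if training and daug_dict:
        for key, augment in daug_dict.items():
            out[key] = augment(out[key])
    # Preprocessing: both training and testing time.
    if dprep_dict:
        for key, preprocess in dprep_dict.items():
            out[key] = preprocess(out[key])
    return out


def center(xs):
    """Stand-in for a DataPreprocessing step: subtract a fixed mean."""
    return [x - 0.5 for x in xs]


def flip(xs):
    """Stand-in for a DataAugmentation step: reverse the sample."""
    return list(reversed(xs))


feed = {"X": [0.0, 1.0], "Y": [1, 0]}
print(apply_realtime_ops(feed, {"X": center}, {"X": flip}, training=True))
# {'X': [0.5, -0.5], 'Y': [1, 0]}
print(apply_realtime_ops(feed, {"X": center}, {"X": flip}, training=False))
# {'X': [-0.5, 0.5], 'Y': [1, 0]}
```

Keeping preprocessing identical at train and test time ensures the model sees inputs on the same scale, while augmentation only adds variety to the training data.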

initialize_training_ops (i, session, tensorboard_verbose, clip_gradients)

Initialize all ops used for training. Because a network can have multiple optimizers, an id 'i' is allocated to differentiate them. This is meant to be used by Trainer when initializing all train ops.

  • i: int. This optimizer training process ID.
  • session: tf.Session. The session used to train the network.
  • tensorboard_verbose: int. Logging verbosity level. Supported values:
      0 - Loss, Accuracy.
      1 - Loss, Accuracy, Gradients.
      2 - Loss, Accuracy, Gradients, Weights.
      3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity.
  • clip_gradients: float. Threshold used for clipping gradients.
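The clip_gradients threshold bounds the size of each update. The sketch below assumes global-norm clipping in the style of tf.clip_by_global_norm, where all gradients are rescaled jointly by clip_norm / max(global_norm, clip_norm). The exact clipping call TFLearn uses is not shown here. Plain lists of floats stand in for tensors, and treating a non-positive threshold as "no clipping" is an assumption of this example.

```python
import math


def clip_by_global_norm(grads, clip_norm):
    """Jointly rescale gradients so their combined L2 norm is <= clip_norm.

    grads: list of gradients, each a flat list of floats standing in for a tensor.
    Returns (clipped_grads, global_norm). A non-positive clip_norm disables
    clipping (an assumption made for this sketch).
    """
    global_norm = math.sqrt(sum(x * x for g in grads for x in g))
    if clip_norm <= 0.0:
        return [list(g) for g in grads], global_norm
    # Same scaling rule as tf.clip_by_global_norm:
    # t * clip_norm / max(global_norm, clip_norm)
    scale = clip_norm / max(global_norm, clip_norm)
    return [[x * scale for x in g] for g in grads], global_norm


clipped, norm = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], clip_norm=1.0)
print(norm)     # 5.0
print(clipped)  # approximately [[0.6, 0.0], [0.0, 0.8]]
```

Scaling every gradient by the same factor preserves the direction of the overall update while limiting its magnitude, unlike clipping each gradient independently.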