Deep Neural Network Model
tflearn.models.dnn.DNN (network, clip_gradients=5.0, tensorboard_verbose=0, tensorboard_dir='/tmp/tflearn_logs/', checkpoint_path=None, best_checkpoint_path=None, max_checkpoints=None, session=None, best_val_accuracy=0.0)
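Model wrapper around a TFLearn network. It provides training (fit), evaluation (evaluate), prediction (predict, predict_label) and weight management (save, load, get_weights, set_weights).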
Arguments
- network: Tensor. Neural network to be used.
- clip_gradients: float. Gradient clipping value. Default: 5.0.
- tensorboard_verbose: int. Summary verbose level. It accepts the following levels of TensorBoard logs (Default: 0):
  0: Loss, Accuracy (best speed).
  1: Loss, Accuracy, Gradients.
  2: Loss, Accuracy, Gradients, Weights.
  3: Loss, Accuracy, Gradients, Weights, Activations, Sparsity (best visualization).
- tensorboard_dir: str. Directory to store TensorBoard logs. Default: "/tmp/tflearn_logs/".
- checkpoint_path: str. Path to store model checkpoints. If None, no model checkpoint will be saved. Default: None.
- best_checkpoint_path: str. Path to store the model when the validation accuracy reaches its highest point in the current training session and is also above best_val_accuracy. Default: None.
- max_checkpoints: int or None. Maximum number of checkpoints to keep. If None, there is no limit. Default: None.
- session: Session. A session for running ops. If None, a new one will be created. Note: when providing a session, variables must already be initialized; otherwise an error will be raised.
- best_val_accuracy: float. The minimum validation accuracy that must be reached before the model's weights are saved to best_checkpoint_path. This lets the user skip early saves and set a minimum save point when continuing to train a reloaded model. Default: 0.0.
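The interaction between best_checkpoint_path and best_val_accuracy can be pictured as a running maximum that starts at best_val_accuracy. The plain-Python sketch below models that rule; it is not TFLearn's code. The class BestCheckpointTracker and its save callback are hypothetical, and it assumes a save happens only when validation accuracy strictly exceeds the best value seen so far.

```python
from typing import Callable, List


class BestCheckpointTracker:
    """Hypothetical model of the best-checkpoint rule (not TFLearn's code).

    The running best starts at best_val_accuracy, so nothing is saved
    until validation accuracy strictly exceeds that threshold.
    """

    def __init__(self, save_fn: Callable[[float], None],
                 best_val_accuracy: float = 0.0) -> None:
        self.save_fn = save_fn
        self.best = best_val_accuracy

    def update(self, val_accuracy: float) -> bool:
        # Save only on a strict improvement over the running best.
        if val_accuracy > self.best:
            self.best = val_accuracy
            self.save_fn(val_accuracy)
            return True
        return False


saved: List[float] = []
tracker = BestCheckpointTracker(saved.append, best_val_accuracy=0.80)
for acc in [0.75, 0.82, 0.81, 0.90]:
    tracker.update(acc)
print(saved)  # [0.82, 0.9]
```

With a threshold of 0.80, the 0.75 epoch is skipped as an "early save", and 0.81 is skipped because 0.82 was already reached.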
Attributes
- trainer: Trainer. Handles model training.
- predictor: Predictor. Handles model prediction.
- session: Session. The current model session.
Methods
evaluate (X, Y, batch_size=128)
Evaluate model metric(s) on given samples.
Arguments
- X: array, list of array (if multiple inputs) or dict (with input layer names as keys). Data to evaluate the model on.
- Y: array, list of array (if multiple outputs) or dict (with estimator layer names as keys). Targets (labels) corresponding to X. For sequence models, these are usually the next elements of the sequence, i.e. for x[0], y[0] = x[1].
- batch_size: int. The batch size. Default: 128.
Returns
The metric(s) score.
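evaluate processes samples in batches of batch_size. Summing correct predictions over all batches and dividing by the total sample count gives the same accuracy as scoring the whole set at once, regardless of batch size. The sketch below shows that arithmetic in plain Python; batched_accuracy and the predict_fn argument are hypothetical stand-ins, and TFLearn's own evaluator may aggregate its metrics differently.

```python
from typing import Callable, List, Sequence


def batched_accuracy(
    predict_fn: Callable[[Sequence[float]], List[int]],
    X: Sequence[float],
    Y: Sequence[int],
    batch_size: int = 128,
) -> float:
    """Accuracy over X/Y computed batch by batch (hypothetical sketch)."""
    if len(X) == 0:
        raise ValueError("X must not be empty")
    correct = 0
    for start in range(0, len(X), batch_size):
        xb = X[start:start + batch_size]
        yb = Y[start:start + batch_size]
        preds = predict_fn(xb)
        # Count hits per batch; summing counts (rather than averaging
        # per-batch accuracies) keeps a short final batch from being
        # over-weighted.
        correct += sum(int(p == y) for p, y in zip(preds, yb))
    return correct / len(X)


def toy_predict(xb: Sequence[float]) -> List[int]:
    # Toy "model": predicts class 1 when the input is positive.
    return [1 if x > 0 else 0 for x in xb]


X = [-2.0, -1.0, 0.5, 1.0, 3.0]
Y = [0, 1, 1, 1, 0]
print(batched_accuracy(toy_predict, X, Y, batch_size=2))  # 0.6
```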
fit (X_inputs, Y_targets, n_epoch=10, validation_set=None, show_metric=False, batch_size=None, shuffle=None, snapshot_epoch=True, snapshot_step=None, excl_trainops=None, validation_batch_size=None, run_id=None, callbacks=[])
Train model, feeding X_inputs and Y_targets to the network.
NOTE: When not feeding dicts, data is assigned to input/estimator layers in the order those layers were created (for example, the second input layer created is fed the second element of the X_inputs list).
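The ordering rule in the NOTE can be modeled in a few lines of plain Python. This is only a sketch: assign_inputs is a hypothetical helper, strings stand in for arrays, and it assumes a Python list always means "one element per input layer" while any other non-dict value feeds a single input layer.

```python
from typing import Any, Dict, Sequence


def assign_inputs(layer_names: Sequence[str], data: Any) -> Dict[str, Any]:
    """Map fed data to input layers (hypothetical model of the NOTE).

    layer_names must be given in the order the layers were created.
    """
    if isinstance(data, dict):
        # Dicts are matched by layer name, so their order does not matter.
        unknown = set(data) - set(layer_names)
        if unknown:
            raise KeyError(f"unknown layer names: {sorted(unknown)}")
        return dict(data)
    if not isinstance(data, list):
        # A single (non-list) value feeds the only input layer.
        data = [data]
    if len(data) != len(layer_names):
        raise ValueError(f"expected {len(layer_names)} inputs, got {len(data)}")
    # Lists are matched by position: i-th element -> i-th created layer.
    return dict(zip(layer_names, data))


# Placeholder strings stand in for arrays.
print(assign_inputs(["input1", "input2"], ["X1", "X2"]))
# {'input1': 'X1', 'input2': 'X2'}
```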
Examples
model.fit(X, Y) # Single input and output
model.fit({'input1': X}, {'output1': Y}) # Single input and output
model.fit([X1, X2], Y) # Multiple inputs, single output
# validate with X_val and [Y1_val, Y2_val]
model.fit(X, [Y1, Y2], validation_set=(X_val, [Y1_val, Y2_val]))
# 10% of training data used for validation
model.fit(X, Y, validation_set=0.1)
Arguments
- X_inputs: array, list of array (if multiple inputs) or dict (with input layer names as keys). Data to feed to train the model.
- Y_targets: array, list of array (if multiple outputs) or dict (with estimator layer names as keys). Targets (labels) to feed to train the model.
- n_epoch: int. Number of epochs to run. Default: 10.
- validation_set: tuple or float. Data used for validation. A tuple holds data and targets (provided as the same types as X_inputs and Y_targets). A float (< 1) instead performs a data split over the training data, using that fraction for validation.
- show_metric: bool. Whether to display accuracy at every step.
- batch_size: int or None. If int, overrides the 'batch_size' of all network estimators with this value. It also overrides validation_batch_size when validation_batch_size is None.
- validation_batch_size: int or None. If int, overrides the 'validation_batch_size' of all network estimators with this value.
- shuffle: bool or None. If bool, overrides the 'shuffle' setting of all network estimators with this value.
- snapshot_epoch: bool. If True, snapshot the model at the end of every epoch. (Snapshotting a model evaluates it on the validation set and creates a checkpoint if 'checkpoint_path' is specified.)
- snapshot_step: int or None. If int, snapshot the model every 'snapshot_step' steps.
- excl_trainops: list of TrainOp. A list of train ops to exclude from the training process (TrainOps can be retrieved through tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)).
- run_id: str. A name for this run (useful for TensorBoard).
- callbacks: Callback or list. Custom callbacks to use in the training life cycle.
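When validation_set is a float, part of the training data is held out for validation. The plain-Python sketch below shows one way to compute such a split over sample indices. split_indices is hypothetical, and the choices to hold out the last int(n_samples * val_fraction) samples and to shuffle only on request are assumptions, not a description of TFLearn's internals.

```python
import random
from typing import List, Optional, Tuple


def split_indices(n_samples: int, val_fraction: float,
                  shuffle: bool = False,
                  seed: Optional[int] = None) -> Tuple[List[int], List[int]]:
    """Split sample indices into (train, validation) (hypothetical sketch)."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be in (0, 1)")
    indices = list(range(n_samples))
    if shuffle:
        # Seeded RNG so the split is reproducible.
        random.Random(seed).shuffle(indices)
    # Assumption: truncate the validation size and take it from the end.
    n_val = int(n_samples * val_fraction)
    split_at = n_samples - n_val
    return indices[:split_at], indices[split_at:]


# Mirrors validation_set=0.1 on a 10-sample training set.
train_idx, val_idx = split_indices(10, 0.1)
print(train_idx, val_idx)  # [0, 1, 2, 3, 4, 5, 6, 7, 8] [9]
```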
get_weights (weight_tensor)
Get a variable's weights.
Examples
dnn = DNN(...)
w = dnn.get_weights(denselayer.W) # get a dense layer's weights
w = dnn.get_weights(convlayer.b) # get a conv layer's biases
Arguments
- weight_tensor: Tensor. A Variable.
Returns
np.array. The provided variable's weights.
load (model_file, weights_only=False, **optargs)
Restore model weights.
Arguments
- model_file: str. Model path.
- weights_only: bool. If True, only weights will be restored (not intermediate variables such as the step counter or moving averages). Note that if you are using batch normalization, its moving averages will not be restored either. Default: False.
- optargs: Optional extra arguments for trainer.restore (see helpers/trainer.py). These may be used to limit the scope of the variables restored, and to control whether a new session is created for the restored variables.
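The effect of weights_only=True can be modeled as a filter over the saved state: model weights are kept, while bookkeeping state such as the step counter and moving averages is skipped. The plain-Python sketch below illustrates that idea only. The variable names, the checkpoint dict and the name-based is_weight rule are hypothetical; the real restore works on TensorFlow variables, not on a mapping like this.

```python
from typing import Callable, Dict

# Hypothetical checkpoint contents: variable name -> saved value.
checkpoint: Dict[str, object] = {
    "FullyConnected/W": [[0.1, 0.2]],
    "FullyConnected/b": [0.0],
    "BatchNormalization/moving_mean": [0.5],
    "Training_step": 1200,
}


def is_weight(name: str) -> bool:
    # Hypothetical rule: the step counter and moving statistics are
    # non-weight state that weights_only should skip.
    return name != "Training_step" and "moving_" not in name


def restore(ckpt: Dict[str, object], weights_only: bool = False,
            keep: Callable[[str], bool] = is_weight) -> Dict[str, object]:
    """Return the subset of ckpt that would be restored (sketch)."""
    if not weights_only:
        return dict(ckpt)
    return {name: value for name, value in ckpt.items() if keep(name)}


print(sorted(restore(checkpoint, weights_only=True)))
# ['FullyConnected/W', 'FullyConnected/b']
```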
predict (X)
Model prediction for given input data.
Arguments
- X: array, list of array (if multiple inputs) or dict (with input layer names as keys). Data to feed for prediction.
Returns
array or list of array. The predicted probabilities.
predict_label (X)
Predict class labels for input X.
Arguments
- X: array, list of array (if multiple inputs) or dict (with input layer names as keys). Data to feed for prediction.
Returns
array or list of array. The predicted class indices, sorted by descending probability.
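The ordering returned by predict_label can be reproduced from the probabilities that predict returns. The sketch below does this in plain Python on a list of per-sample probability lists. labels_by_probability is a hypothetical helper, not part of TFLearn, and its tie-breaking (equal probabilities keep their original index order) may differ from the library's.

```python
from typing import List, Sequence


def labels_by_probability(probs: Sequence[Sequence[float]]) -> List[List[int]]:
    """For each sample, class indices sorted by descending probability."""
    # sorted() is stable, so tied probabilities keep ascending index order.
    return [sorted(range(len(p)), key=lambda i: p[i], reverse=True)
            for p in probs]


probs = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]
print(labels_by_probability(probs))  # [[1, 2, 0], [0, 1, 2]]
```

The first entry of each row is the most likely class, which is what a single-label prediction would use.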
save (model_file)
Save model weights.
Arguments
- model_file: str. Model path.
set_weights (tensor, weights)
Assign a tensor variable a given value.
Arguments
- tensor: Tensor. The tensor variable to assign a value to.
- weights: The value to be assigned.