Trainer
Trainer for training.
- class mindnlp.engine.trainer.Trainer(network=None, train_dataset=None, eval_dataset=None, metrics=None, epochs=10, loss_fn=None, optimizer=None, callbacks=None, jit=False)
Bases: object
Trainer to train the model.
- Parameters
network (Cell) – A training network.
train_dataset (Dataset) – A training dataset iterator. If loss_fn is defined, the data is passed to the network and the label is passed to loss_fn, so the dataset should return a tuple (data, label). If there are multiple data or label columns, set loss_fn to None and implement the loss calculation in the network; the full tuple (data1, data2, data3, …) returned by the dataset is then passed to the network.
eval_dataset (Dataset) – An evaluation dataset iterator. If loss_fn is defined, the data is passed to the network and the label is passed to loss_fn, so the dataset should return a tuple (data, label). If there are multiple data or label columns, set loss_fn to None and implement the loss calculation in the network; the full tuple (data1, data2, data3, …) returned by the dataset is then passed to the network.
metrics (Optional[list[Metrics], Metrics]) – A metric object, or a list of metric objects, used during evaluation. Default: None.
epochs (int) – Total number of passes (epochs) over the training dataset. Default: 10.
optimizer (Cell) – Optimizer for updating the weights. If optimizer is None, the network itself must perform backpropagation and update the weights. Default: None.
loss_fn (Cell) – Objective function. If loss_fn is None, the network must contain the loss calculation (and any parallel-training logic, if needed). Default: None.
callbacks (Optional[list[Callback], Callback]) – A callback object, or a list of callback objects, executed during training. Default: None.
jit (bool) – Whether to use Just-In-Time (JIT) compilation. Default: False.
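The data-routing rule described for train_dataset and eval_dataset can be illustrated with a minimal, framework-free sketch. It assumes each dataset item is a Python tuple; forward_step, double, squared_error, and net_with_loss are hypothetical plain-Python stand-ins for MindSpore cells, and this is not MindNLP's actual implementation.

```python
from typing import Any, Callable, Optional, Sequence


def forward_step(
    network: Callable[..., Any],
    batch: Sequence[Any],
    loss_fn: Optional[Callable[[Any, Any], Any]] = None,
) -> Any:
    """Hypothetical sketch of how one dataset item might be routed.

    - loss_fn given: the item must be (data, label); data goes to the
      network, and the network output plus the label go to loss_fn.
    - loss_fn is None: every column is passed to the network, which is
      expected to compute and return the loss itself.
    """
    if loss_fn is not None:
        if len(batch) != 2:
            raise ValueError("with loss_fn set, the dataset must yield (data, label)")
        data, label = batch
        output = network(data)
        return loss_fn(output, label)
    # loss_fn is None: unpack all columns into the network
    return network(*batch)


# Toy stand-ins for a network cell and a loss cell (hypothetical)
def double(x):
    return 2 * x


def squared_error(pred, target):
    return (pred - target) ** 2


def net_with_loss(x, y, z):
    # A "network" that computes its own loss from three data columns
    return (x + y - z) ** 2


print(forward_step(double, (3, 5), loss_fn=squared_error))  # (6 - 5) ** 2 -> 1
print(forward_step(net_with_loss, (1, 2, 5)))  # (1 + 2 - 5) ** 2 -> 4
```

When optimizer is None, the parameter description says the network must also run backpropagation and update its own weights; that step is deliberately left out of this sketch.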
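The metrics and callbacks parameters accept None, a single object, or a list. A trainer would plausibly normalize these into a list before iterating over them; as_list below is a hypothetical helper sketching that normalization, not a MindNLP function, and plain strings stand in for metric objects.

```python
from typing import List, Optional, TypeVar, Union

T = TypeVar("T")


def as_list(value: Optional[Union[T, List[T]]]) -> List[T]:
    """Normalize None, a single object, or a list/tuple into a list (hypothetical)."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


print(as_list(None))  # []
print(as_list("accuracy"))  # ['accuracy']
print(as_list(["acc", "f1"]))  # ['acc', 'f1']
```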