*tuple and **dictionary
A real-world example, from /mine_v2/training/BaseTraining.py:
self.loss = ug.get_class(self.conf.get_string('network.loss.loss_type'))(**self.conf.get_config('network.loss.properties'))
in a function call
*t means “treat the elements of this iterable as positional arguments to this function call.”
def f(x, y): print(x, y)
t = (1, 2)
f(*t)  # equivalent to f(1, 2)
**d means “treat the key-value pairs in the dictionary as additional named arguments to this function call.”
def f(x, y): print(x, y)
d = {'x': 1, 'y': 2}
f(**d)  # equivalent to f(x=1, y=2)
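The two call-site forms can also be combined in a single call; a quick sketch (the names t and kw are just illustrative):

```python
def f(a, b, c, d):
    return (a, b, c, d)

t = (1, 2)              # unpacked as positional arguments
kw = {'c': 3, 'd': 4}   # unpacked as keyword arguments

# positional and keyword unpacking in one call
assert f(*t, **kw) == (1, 2, 3, 4)
```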
in a function signature
*t means “take all additional positional arguments to this function and pack them into this parameter as a tuple.”
def f(*t): print(t)
f(1, 2, 3)  # inside f, t == (1, 2, 3)
**d means “take all additional named arguments to this function and insert them into this parameter as dictionary entries.”
def f(**d): print(d)
f(x=1, y=2)  # inside f, d == {'x': 1, 'y': 2}
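The two signature forms are often combined to forward arbitrary arguments through a wrapper; a minimal sketch (the functions here are illustrative, not from the source):

```python
def target(x, y, scale=1):
    return (x + y) * scale

def wrapper(*args, **kwargs):
    # pack everything the caller passed, then unpack it into target
    return target(*args, **kwargs)

assert wrapper(1, 2) == 3
assert wrapper(1, 2, scale=10) == 30
```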
torch.save() & torch.load()
save
step1. define a dictionary and the checkpoint path
state = {'epoch': model.epoch,
         'model_state_dict': model.state_dict()}
step2. call torch.save()
torch.save(state, dir)
load
if you want to resume the training run from the checkpoint:
checkpoint = torch.load(dir)
model.epoch = checkpoint['epoch']  # restore whatever was stored in `state`
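The save/load round trip above can be sketched without torch, since torch.save and torch.load are built on pickle-style serialization; the keys and values here are just stand-ins for the real checkpoint:

```python
import io
import pickle

# stand-in for the checkpoint dictionary built in step 1
state = {'epoch': 10, 'loss': 0.25}

buf = io.BytesIO()             # stands in for the file at `dir`
pickle.dump(state, buf)        # torch.save(state, dir)

buf.seek(0)
checkpoint = pickle.load(buf)  # checkpoint = torch.load(dir)
assert checkpoint['epoch'] == 10
```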
model.train()
with torch.no_grad()
A question: if we disable gradient calculation in the plot module like this:
if network.epoch == plot.epoch:
    with torch.no_grad():
        ...  # plotting code
then, after we call .train(), will gradient calculation be restored? In fact the two mechanisms are independent: torch.no_grad() is a context manager, so gradient calculation comes back automatically as soon as the with block exits; .train() and .eval() do not touch gradient tracking at all.
We can call either network.eval() or network.train(mode=False) to signal that we are testing.
Besides, we can also check the network.training flag: it is False in eval mode. Currently only Dropout and BatchNorm care about that flag. (https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)
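How .train()/.eval() flip the training flag can be sketched without torch; this is a simplified stand-in for nn.Module (an assumption: it ignores submodules, which the real train() also updates recursively):

```python
class Module:
    """Minimal sketch of nn.Module's training-mode switch."""
    def __init__(self):
        self.training = True  # modules start in training mode

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        # eval() is just train(mode=False)
        return self.train(False)

net = Module()
net.eval()
assert net.training is False  # Dropout/BatchNorm would switch behavior here
net.train()
assert net.training is True   # back to training mode
```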