nn.Module
Base class for all neural network modules.
common usage
    import torch.nn as nn
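Typical use is to subclass nn.Module, create the layers in __init__, and define the computation in forward. The sketch below is illustrative only. The class name TinyNet and the layer sizes are hypothetical.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical example network
    def __init__(self, in_features: int = 4, hidden: int = 8, out_features: int = 2):
        super().__init__()  # must run before any submodule is assigned
        # Submodules assigned as attributes are registered automatically,
        # so their parameters appear in model.parameters().
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


model = TinyNet()
x = torch.randn(3, 4)  # batch of 3 samples, 4 features each
out = model(x)         # call the module itself so registered hooks run
print(out.shape)       # torch.Size([3, 2])

n_params = sum(p.numel() for p in model.parameters())
print(n_params)        # 58 = (4*8 + 8) + (8*2 + 2)
```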
to()
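Example:
    model = network().to(device)
This moves the model's parameters and buffers to the given device, usually a GPU. For a module the move happens in place, and to() also returns the module itself.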
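Here is a minimal sketch of picking a device and moving both the model and its inputs there. It uses nn.Linear as a stand-in for the network() in the example. The layer size is an arbitrary assumption.

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2)   # stand-in for network()
model = model.to(device)  # moves parameters and buffers, returns the module

# Inputs must be on the same device as the model's parameters.
x = torch.randn(3, 4).to(device)
out = model(x)

print(next(model.parameters()).device, out.device)
```

For a module, calling model.to(device) without reassigning is enough, because the module is modified in place. For a tensor, to() returns a new tensor when a move is needed, so the result must be reassigned (x = x.to(device)).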
model.eval() & with torch.no_grad()
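These two serve different purposes:
- model.eval() switches the model and all of its submodules to evaluation mode (it sets training = False on each of them). Layers that behave differently during training and inference react to this. Dropout is disabled, and batch norm normalizes with its running statistics instead of the statistics of the current batch.
- torch.no_grad() affects the autograd engine. Inside the block gradient tracking is disabled, so no computation graph is built. This saves memory and time, but it does not change how any layer behaves.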
- The common practice for evaluation/validation is to use with torch.no_grad() together with model.eval(). Afterwards, don't forget to switch back to training mode with model.train():

    model.eval()              # dropout off, batch norm uses running statistics
    with torch.no_grad():     # no computation graph is recorded in this block
        ...
        output = model(data)
        ...
    # training step
    model.train()             # restore training behavior before training resumes
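The effect of model.eval() on dropout and batch norm can be checked directly. This is a small sketch on a toy model. The seed, the layer sizes and the dropout probability are arbitrary choices made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the run is repeatable

# Toy model containing both layer types mentioned above.
model = nn.Sequential(nn.BatchNorm1d(4), nn.Dropout(p=0.5))
x = torch.randn(8, 4)  # batch of 8 samples, 4 features

model.train()       # dropout active, batch norm uses batch statistics
y_train = model(x)  # this pass also updates batch norm's running statistics

model.eval()        # sets training=False on the container and every submodule
print(model.training, model[0].training, model[1].training)  # False False False

y_eval_1 = model(x)
y_eval_2 = model(x)
# Dropout is now the identity and batch norm is deterministic,
# so two calls on the same input produce identical outputs.
print(torch.equal(y_eval_1, y_eval_2))  # True

# In eval mode batch norm normalizes with the stored running statistics.
bn = model[0]
expected = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps) * bn.weight + bn.bias
print(torch.allclose(y_eval_1, expected, atol=1e-6))  # True
```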
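This is a minimal sketch of what torch.no_grad() changes, followed by a validation helper that combines it with model.eval(). The helper name evaluate, the mean squared error loss and the random batches are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 2)  # stand-in model
x = torch.randn(3, 4)

out = model(x)
print(out.requires_grad)        # True: autograd recorded a graph for backward()

with torch.no_grad():
    out_ng = model(x)
print(out_ng.requires_grad)     # False: nothing was recorded inside the block
print(torch.is_grad_enabled())  # True: tracking is back on after the block


def evaluate(model: nn.Module, batches) -> float:
    """Hypothetical helper: mean squared error over (inputs, targets) batches."""
    was_training = model.training
    model.eval()  # inference behavior for dropout / batch norm
    total, count = 0.0, 0
    with torch.no_grad():  # no graph needed, backward() is never called
        for inputs, targets in batches:
            loss = F.mse_loss(model(inputs), targets, reduction="sum")
            total += loss.item()
            count += targets.numel()
    model.train(was_training)  # restore whatever mode the model was in
    return total / count


batches = [(torch.randn(5, 4), torch.randn(5, 2)) for _ in range(3)]
val_loss = evaluate(model, batches)
print(model.training)  # True: the model is back in training mode
```

Calling model.train(was_training) instead of a bare model.train() means the helper also behaves correctly if it is ever called on a model that was already in eval mode.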