more typos
This commit is contained in:
parent
cc4da94646
commit
3b0032cb2f
@ -139,7 +139,7 @@ class Loss(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return self.loss(x, self.target)
|
||||
|
||||
# Generated TorchScript modules and serialise
|
||||
# Generate TorchScript modules and serialise them
|
||||
torch.jit.script(Data()).save('data.pt')
|
||||
torch.jit.script(Net()).save('net.pt')
|
||||
torch.jit.script(Loss(y_train)).save('loss.pt')
|
||||
@ -165,12 +165,12 @@ NoaFloat {
|
||||
|
||||
// Set the model in training mode
|
||||
netModule.train(true)
|
||||
// Loss function set for training regime
|
||||
// Loss function for training
|
||||
lossModule.setBuffer("target", yTrain)
|
||||
|
||||
// Compute the predictions
|
||||
val yPred = netModule.forward(xTrain)
|
||||
// Compute the training error
|
||||
// Compute the training loss
|
||||
val loss = lossModule.forward(yPred)
|
||||
println(loss)
|
||||
|
||||
@ -195,7 +195,7 @@ NoaFloat {
|
||||
// Finally validate the model
|
||||
// Compute the predictions for the validation features
|
||||
netModule.forwardAssign(xVal, yPred)
|
||||
// Set the loss for the true values
|
||||
// Set the loss for validation
|
||||
lossModule.setBuffer("target", yVal)
|
||||
// Compute the loss on validation dataset
|
||||
lossModule.forwardAssign(yPred, loss)
|
||||
|
Loading…
Reference in New Issue
Block a user