more typos
This commit is contained in:
parent
cc4da94646
commit
3b0032cb2f
@ -139,7 +139,7 @@ class Loss(torch.nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.loss(x, self.target)
|
return self.loss(x, self.target)
|
||||||
|
|
||||||
# Generated TorchScript modules and serialise
|
# Generate TorchScript modules and serialise them
|
||||||
torch.jit.script(Data()).save('data.pt')
|
torch.jit.script(Data()).save('data.pt')
|
||||||
torch.jit.script(Net()).save('net.pt')
|
torch.jit.script(Net()).save('net.pt')
|
||||||
torch.jit.script(Loss(y_train)).save('loss.pt')
|
torch.jit.script(Loss(y_train)).save('loss.pt')
|
||||||
@ -165,12 +165,12 @@ NoaFloat {
|
|||||||
|
|
||||||
// Set the model in training mode
|
// Set the model in training mode
|
||||||
netModule.train(true)
|
netModule.train(true)
|
||||||
// Loss function set for training regime
|
// Loss function for training
|
||||||
lossModule.setBuffer("target", yTrain)
|
lossModule.setBuffer("target", yTrain)
|
||||||
|
|
||||||
// Compute the predictions
|
// Compute the predictions
|
||||||
val yPred = netModule.forward(xTrain)
|
val yPred = netModule.forward(xTrain)
|
||||||
// Compute the training error
|
// Compute the training loss
|
||||||
val loss = lossModule.forward(yPred)
|
val loss = lossModule.forward(yPred)
|
||||||
println(loss)
|
println(loss)
|
||||||
|
|
||||||
@ -195,7 +195,7 @@ NoaFloat {
|
|||||||
// Finally validate the model
|
// Finally validate the model
|
||||||
// Compute the predictions for the validation features
|
// Compute the predictions for the validation features
|
||||||
netModule.forwardAssign(xVal, yPred)
|
netModule.forwardAssign(xVal, yPred)
|
||||||
// Set the loss for the true values
|
// Set the loss for validation
|
||||||
lossModule.setBuffer("target", yVal)
|
lossModule.setBuffer("target", yVal)
|
||||||
// Compute the loss on validation dataset
|
// Compute the loss on validation dataset
|
||||||
lossModule.forwardAssign(yPred, loss)
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
Loading…
Reference in New Issue
Block a user