diff --git a/kmath-noa/README.md b/kmath-noa/README.md index 7bc3f6db4..b4c255ea2 100644 --- a/kmath-noa/README.md +++ b/kmath-noa/README.md @@ -139,7 +139,7 @@ class Loss(torch.nn.Module): def forward(self, x): return self.loss(x, self.target) -# Generated TorchScript modules and serialise +# Generate TorchScript modules and serialise them torch.jit.script(Data()).save('data.pt') torch.jit.script(Net()).save('net.pt') torch.jit.script(Loss(y_train)).save('loss.pt') @@ -165,12 +165,12 @@ NoaFloat { // Set the model in training mode netModule.train(true) - // Loss function set for training regime + // Loss function for training lossModule.setBuffer("target", yTrain) // Compute the predictions val yPred = netModule.forward(xTrain) - // Compute the training error + // Compute the training loss val loss = lossModule.forward(yPred) println(loss) @@ -195,7 +195,7 @@ NoaFloat { // Finally validate the model // Compute the predictions for the validation features netModule.forwardAssign(xVal, yPred) - // Set the loss for the true values + // Set the loss for validation lossModule.setBuffer("target", yVal) // Compute the loss on validation dataset lossModule.forwardAssign(yPred, loss)