forked from kscience/kmath
Batched operations in algebra
This commit is contained in:
parent
ca3cca65ef
commit
80f28dbcd5
@ -18,6 +18,8 @@ extern "C"
|
||||
|
||||
void set_seed(int seed);
|
||||
|
||||
TorchTensorHandle empty_tensor();
|
||||
|
||||
double *get_data_double(TorchTensorHandle tensor_handle);
|
||||
float *get_data_float(TorchTensorHandle tensor_handle);
|
||||
long *get_data_long(TorchTensorHandle tensor_handle);
|
||||
@ -115,9 +117,21 @@ extern "C"
|
||||
|
||||
bool requires_grad(TorchTensorHandle tensor_handle);
|
||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
||||
void detach_from_graph(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
||||
|
||||
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2);
|
||||
|
||||
void svd_tensor(TorchTensorHandle tensor_handle,
|
||||
TorchTensorHandle U_handle,
|
||||
TorchTensorHandle S_handle,
|
||||
TorchTensorHandle V_handle);
|
||||
|
||||
void symeig_tensor(TorchTensorHandle tensor_handle,
|
||||
TorchTensorHandle S_handle,
|
||||
TorchTensorHandle V_handle,
|
||||
bool eigenvectors);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -25,6 +25,11 @@ void set_seed(int seed)
|
||||
torch::manual_seed(seed);
|
||||
}
|
||||
|
||||
TorchTensorHandle empty_tensor()
|
||||
{
|
||||
return new torch::Tensor;
|
||||
}
|
||||
|
||||
double *get_data_double(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<double>();
|
||||
@ -359,11 +364,37 @@ void requires_grad_(TorchTensorHandle tensor_handle, bool status)
|
||||
{
|
||||
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||
}
|
||||
void detach_from_graph(TorchTensorHandle tensor_handle)
|
||||
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle).detach();
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
||||
}
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
||||
{
|
||||
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
|
||||
}
|
||||
|
||||
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2)
|
||||
{
|
||||
return new torch::Tensor(torch::diag_embed(ctorch::cast(diags_handle), offset, dim1, dim2));
|
||||
}
|
||||
|
||||
void svd_tensor(TorchTensorHandle tensor_handle,
|
||||
TorchTensorHandle U_handle,
|
||||
TorchTensorHandle S_handle,
|
||||
TorchTensorHandle V_handle)
|
||||
{
|
||||
auto [U, S, V] = torch::svd(ctorch::cast(tensor_handle));
|
||||
ctorch::cast(U_handle) = U;
|
||||
ctorch::cast(S_handle) = S;
|
||||
ctorch::cast(V_handle) = V;
|
||||
}
|
||||
|
||||
void symeig_tensor(TorchTensorHandle tensor_handle,
|
||||
TorchTensorHandle S_handle,
|
||||
TorchTensorHandle V_handle,
|
||||
bool eigenvectors)
|
||||
{
|
||||
auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
|
||||
ctorch::cast(S_handle) = S;
|
||||
ctorch::cast(V_handle) = V;
|
||||
}
|
||||
|
@ -6,4 +6,8 @@ import kotlin.test.*
|
||||
internal class TestAutogradGPU {
|
||||
@Test
|
||||
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
|
||||
|
||||
@Test
|
||||
fun testBatchedAutoGrad() = testingBatchedAutoGrad(
|
||||
bath = intArrayOf(2), dim=3, device = TorchDevice.TorchCUDA(0))
|
||||
}
|
@ -21,4 +21,13 @@ class TestTorchTensorAlgebraGPU {
|
||||
fun testTensorTransformations() =
|
||||
testingTensorTransformations(device = TorchDevice.TorchCUDA(0))
|
||||
|
||||
@Test
|
||||
fun testBatchedSVD() =
|
||||
testingBatchedSVD(device = TorchDevice.TorchCUDA(0))
|
||||
|
||||
@Test
|
||||
fun testBatchedSymEig() =
|
||||
testingBatchedSymEig(device = TorchDevice.TorchCUDA(0))
|
||||
|
||||
|
||||
}
|
||||
|
@ -53,9 +53,6 @@ public sealed class TorchTensor<T> constructor(
|
||||
get() = requires_grad(tensorHandle)
|
||||
set(value) = requires_grad_(tensorHandle, value)
|
||||
|
||||
public fun detachFromGraph(): Unit {
|
||||
detach_from_graph(tensorHandle)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorReal internal constructor(
|
||||
|
@ -84,6 +84,7 @@ public sealed class TorchTensorAlgebra<
|
||||
|
||||
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
|
||||
wrap(transpose_tensor(tensorHandle, i, j)!!)
|
||||
|
||||
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
|
||||
transpose_tensor_assign(tensorHandle, i, j)
|
||||
}
|
||||
@ -93,13 +94,37 @@ public sealed class TorchTensorAlgebra<
|
||||
sum_tensor_assign(tensorHandle)
|
||||
}
|
||||
|
||||
public fun diagEmbed(diagonalEntries: TorchTensorType,
|
||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1): TorchTensorType =
|
||||
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
|
||||
|
||||
public fun TorchTensorType.svd(): Triple<TorchTensorType,TorchTensorType,TorchTensorType> {
|
||||
val U = empty_tensor()!!
|
||||
val V = empty_tensor()!!
|
||||
val S = empty_tensor()!!
|
||||
svd_tensor(this.tensorHandle, U, S, V)
|
||||
return Triple(wrap(U), wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType> {
|
||||
val V = empty_tensor()!!
|
||||
val S = empty_tensor()!!
|
||||
symeig_tensor(this.tensorHandle, S, V, eigenvectors)
|
||||
return Pair(wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||
|
||||
public fun TorchTensorType.copy(): TorchTensorType =
|
||||
wrap(tensorHandle = copy_tensor(this.tensorHandle)!!)
|
||||
|
||||
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
|
||||
wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!)
|
||||
|
||||
public fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
|
||||
|
||||
}
|
||||
|
||||
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
||||
|
@ -22,7 +22,36 @@ internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCP
|
||||
}
|
||||
}
|
||||
|
||||
internal fun testingBatchedAutoGrad(bath: IntArray,
|
||||
dim: Int,
|
||||
device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
setSeed(SEED)
|
||||
|
||||
val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
||||
tensorX.requiresGrad = true
|
||||
val tensorXt = tensorX.transpose(-1,-2)
|
||||
val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device)
|
||||
val tensorSigma = randFeatures + randFeatures.transpose(-2,-1)
|
||||
val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
||||
|
||||
val expressionAtX =
|
||||
0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2
|
||||
expressionAtX.sumAssign()
|
||||
|
||||
val gradientAtX = expressionAtX grad tensorX
|
||||
val expectedGradientAtX = (tensorX dot tensorSigma) + tensorMu
|
||||
|
||||
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
|
||||
assertTrue(error < TOLERANCE)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
internal class TestAutograd {
|
||||
@Test
|
||||
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
||||
|
||||
@Test
|
||||
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30)
|
||||
}
|
@ -45,5 +45,8 @@ class TestTorchTensor {
|
||||
assertTrue(tensor.requiresGrad)
|
||||
tensor.requiresGrad = false
|
||||
assertTrue(!tensor.requiresGrad)
|
||||
tensor.requiresGrad = true
|
||||
val detachedTensor = tensor.detachFromGraph()
|
||||
assertTrue(!detachedTensor.requiresGrad)
|
||||
}
|
||||
}
|
@ -90,15 +90,34 @@ internal fun testingTensorTransformations(device: TorchDevice = TorchDevice.Torc
|
||||
val tensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||
val result = tensor.exp().log()
|
||||
val assignResult = tensor.copy()
|
||||
assignResult.transposeAssign(0,1)
|
||||
assignResult.transposeAssign(0, 1)
|
||||
assignResult.expAssign()
|
||||
assignResult.logAssign()
|
||||
assignResult.transposeAssign(0,1)
|
||||
assignResult.transposeAssign(0, 1)
|
||||
val error = tensor - result
|
||||
error.absAssign()
|
||||
error.sumAssign()
|
||||
error += (tensor - assignResult).abs().sum()
|
||||
assertTrue (error.value()< TOLERANCE)
|
||||
assertTrue(error.value() < TOLERANCE)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun testingBatchedSVD(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
val tensor = randNormal(shape = intArrayOf(7, 5, 3), device = device)
|
||||
val (tensorU, tensorS, tensorV) = tensor.svd()
|
||||
val error = tensor - (tensorU dot (diagEmbed(tensorS) dot tensorV.transpose(-2,-1)))
|
||||
assertTrue(error.abs().sum().value() < TOLERANCE)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun testingBatchedSymEig(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
val tensor = randNormal(shape = intArrayOf(5,5), device = device)
|
||||
val tensorSigma = tensor + tensor.transpose(-2,-1)
|
||||
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||
val error = tensorSigma - (tensorV dot (diagEmbed(tensorS) dot tensorV.transpose(-2,-1)))
|
||||
assertTrue(error.abs().sum().value() < TOLERANCE)
|
||||
}
|
||||
}
|
||||
|
||||
@ -116,4 +135,10 @@ internal class TestTorchTensorAlgebra {
|
||||
@Test
|
||||
fun testTensorTransformations() = testingTensorTransformations()
|
||||
|
||||
@Test
|
||||
fun testBatchedSVD() = testingBatchedSVD()
|
||||
|
||||
@Test
|
||||
fun testBatchedSymEig() = testingBatchedSymEig()
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user