forked from kscience/kmath
Batched operations in algebra
This commit is contained in:
parent
ca3cca65ef
commit
80f28dbcd5
@ -18,6 +18,8 @@ extern "C"
|
|||||||
|
|
||||||
void set_seed(int seed);
|
void set_seed(int seed);
|
||||||
|
|
||||||
|
TorchTensorHandle empty_tensor();
|
||||||
|
|
||||||
double *get_data_double(TorchTensorHandle tensor_handle);
|
double *get_data_double(TorchTensorHandle tensor_handle);
|
||||||
float *get_data_float(TorchTensorHandle tensor_handle);
|
float *get_data_float(TorchTensorHandle tensor_handle);
|
||||||
long *get_data_long(TorchTensorHandle tensor_handle);
|
long *get_data_long(TorchTensorHandle tensor_handle);
|
||||||
@ -115,9 +117,21 @@ extern "C"
|
|||||||
|
|
||||||
bool requires_grad(TorchTensorHandle tensor_handle);
|
bool requires_grad(TorchTensorHandle tensor_handle);
|
||||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
||||||
void detach_from_graph(TorchTensorHandle tensor_handle);
|
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
|
||||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
||||||
|
|
||||||
|
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2);
|
||||||
|
|
||||||
|
void svd_tensor(TorchTensorHandle tensor_handle,
|
||||||
|
TorchTensorHandle U_handle,
|
||||||
|
TorchTensorHandle S_handle,
|
||||||
|
TorchTensorHandle V_handle);
|
||||||
|
|
||||||
|
void symeig_tensor(TorchTensorHandle tensor_handle,
|
||||||
|
TorchTensorHandle S_handle,
|
||||||
|
TorchTensorHandle V_handle,
|
||||||
|
bool eigenvectors);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -25,6 +25,11 @@ void set_seed(int seed)
|
|||||||
torch::manual_seed(seed);
|
torch::manual_seed(seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle empty_tensor()
|
||||||
|
{
|
||||||
|
return new torch::Tensor;
|
||||||
|
}
|
||||||
|
|
||||||
double *get_data_double(TorchTensorHandle tensor_handle)
|
double *get_data_double(TorchTensorHandle tensor_handle)
|
||||||
{
|
{
|
||||||
return ctorch::cast(tensor_handle).data_ptr<double>();
|
return ctorch::cast(tensor_handle).data_ptr<double>();
|
||||||
@ -359,11 +364,37 @@ void requires_grad_(TorchTensorHandle tensor_handle, bool status)
|
|||||||
{
|
{
|
||||||
ctorch::cast(tensor_handle).requires_grad_(status);
|
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||||
}
|
}
|
||||||
void detach_from_graph(TorchTensorHandle tensor_handle)
|
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
|
||||||
{
|
{
|
||||||
ctorch::cast(tensor_handle).detach();
|
return new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
||||||
}
|
}
|
||||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
|
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(torch::diag_embed(ctorch::cast(diags_handle), offset, dim1, dim2));
|
||||||
|
}
|
||||||
|
|
||||||
|
void svd_tensor(TorchTensorHandle tensor_handle,
|
||||||
|
TorchTensorHandle U_handle,
|
||||||
|
TorchTensorHandle S_handle,
|
||||||
|
TorchTensorHandle V_handle)
|
||||||
|
{
|
||||||
|
auto [U, S, V] = torch::svd(ctorch::cast(tensor_handle));
|
||||||
|
ctorch::cast(U_handle) = U;
|
||||||
|
ctorch::cast(S_handle) = S;
|
||||||
|
ctorch::cast(V_handle) = V;
|
||||||
|
}
|
||||||
|
|
||||||
|
void symeig_tensor(TorchTensorHandle tensor_handle,
|
||||||
|
TorchTensorHandle S_handle,
|
||||||
|
TorchTensorHandle V_handle,
|
||||||
|
bool eigenvectors)
|
||||||
|
{
|
||||||
|
auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
|
||||||
|
ctorch::cast(S_handle) = S;
|
||||||
|
ctorch::cast(V_handle) = V;
|
||||||
|
}
|
||||||
|
@ -6,4 +6,8 @@ import kotlin.test.*
|
|||||||
internal class TestAutogradGPU {
|
internal class TestAutogradGPU {
|
||||||
@Test
|
@Test
|
||||||
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
|
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedAutoGrad() = testingBatchedAutoGrad(
|
||||||
|
bath = intArrayOf(2), dim=3, device = TorchDevice.TorchCUDA(0))
|
||||||
}
|
}
|
@ -21,4 +21,13 @@ class TestTorchTensorAlgebraGPU {
|
|||||||
fun testTensorTransformations() =
|
fun testTensorTransformations() =
|
||||||
testingTensorTransformations(device = TorchDevice.TorchCUDA(0))
|
testingTensorTransformations(device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedSVD() =
|
||||||
|
testingBatchedSVD(device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedSymEig() =
|
||||||
|
testingBatchedSymEig(device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -53,9 +53,6 @@ public sealed class TorchTensor<T> constructor(
|
|||||||
get() = requires_grad(tensorHandle)
|
get() = requires_grad(tensorHandle)
|
||||||
set(value) = requires_grad_(tensorHandle, value)
|
set(value) = requires_grad_(tensorHandle, value)
|
||||||
|
|
||||||
public fun detachFromGraph(): Unit {
|
|
||||||
detach_from_graph(tensorHandle)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorReal internal constructor(
|
public class TorchTensorReal internal constructor(
|
||||||
|
@ -84,6 +84,7 @@ public sealed class TorchTensorAlgebra<
|
|||||||
|
|
||||||
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
|
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
|
||||||
wrap(transpose_tensor(tensorHandle, i, j)!!)
|
wrap(transpose_tensor(tensorHandle, i, j)!!)
|
||||||
|
|
||||||
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
|
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
|
||||||
transpose_tensor_assign(tensorHandle, i, j)
|
transpose_tensor_assign(tensorHandle, i, j)
|
||||||
}
|
}
|
||||||
@ -93,13 +94,37 @@ public sealed class TorchTensorAlgebra<
|
|||||||
sum_tensor_assign(tensorHandle)
|
sum_tensor_assign(tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun diagEmbed(diagonalEntries: TorchTensorType,
|
||||||
|
offset: Int = 0, dim1: Int = -2, dim2: Int = -1): TorchTensorType =
|
||||||
|
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
|
||||||
|
|
||||||
|
public fun TorchTensorType.svd(): Triple<TorchTensorType,TorchTensorType,TorchTensorType> {
|
||||||
|
val U = empty_tensor()!!
|
||||||
|
val V = empty_tensor()!!
|
||||||
|
val S = empty_tensor()!!
|
||||||
|
svd_tensor(this.tensorHandle, U, S, V)
|
||||||
|
return Triple(wrap(U), wrap(S), wrap(V))
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType> {
|
||||||
|
val V = empty_tensor()!!
|
||||||
|
val S = empty_tensor()!!
|
||||||
|
symeig_tensor(this.tensorHandle, S, V, eigenvectors)
|
||||||
|
return Pair(wrap(S), wrap(V))
|
||||||
|
}
|
||||||
|
|
||||||
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||||
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||||
|
|
||||||
public fun TorchTensorType.copy(): TorchTensorType =
|
public fun TorchTensorType.copy(): TorchTensorType =
|
||||||
wrap(tensorHandle = copy_tensor(this.tensorHandle)!!)
|
wrap(tensorHandle = copy_tensor(this.tensorHandle)!!)
|
||||||
|
|
||||||
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
|
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
|
||||||
wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!)
|
wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!)
|
||||||
|
|
||||||
|
public fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||||
|
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
||||||
|
@ -22,7 +22,36 @@ internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCP
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal fun testingBatchedAutoGrad(bath: IntArray,
|
||||||
|
dim: Int,
|
||||||
|
device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
|
setSeed(SEED)
|
||||||
|
|
||||||
|
val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
||||||
|
tensorX.requiresGrad = true
|
||||||
|
val tensorXt = tensorX.transpose(-1,-2)
|
||||||
|
val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device)
|
||||||
|
val tensorSigma = randFeatures + randFeatures.transpose(-2,-1)
|
||||||
|
val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device)
|
||||||
|
|
||||||
|
val expressionAtX =
|
||||||
|
0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2
|
||||||
|
expressionAtX.sumAssign()
|
||||||
|
|
||||||
|
val gradientAtX = expressionAtX grad tensorX
|
||||||
|
val expectedGradientAtX = (tensorX dot tensorSigma) + tensorMu
|
||||||
|
|
||||||
|
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
|
||||||
|
assertTrue(error < TOLERANCE)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
internal class TestAutograd {
|
internal class TestAutograd {
|
||||||
@Test
|
@Test
|
||||||
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30)
|
||||||
}
|
}
|
@ -45,5 +45,8 @@ class TestTorchTensor {
|
|||||||
assertTrue(tensor.requiresGrad)
|
assertTrue(tensor.requiresGrad)
|
||||||
tensor.requiresGrad = false
|
tensor.requiresGrad = false
|
||||||
assertTrue(!tensor.requiresGrad)
|
assertTrue(!tensor.requiresGrad)
|
||||||
|
tensor.requiresGrad = true
|
||||||
|
val detachedTensor = tensor.detachFromGraph()
|
||||||
|
assertTrue(!detachedTensor.requiresGrad)
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -102,6 +102,25 @@ internal fun testingTensorTransformations(device: TorchDevice = TorchDevice.Torc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal fun testingBatchedSVD(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
|
val tensor = randNormal(shape = intArrayOf(7, 5, 3), device = device)
|
||||||
|
val (tensorU, tensorS, tensorV) = tensor.svd()
|
||||||
|
val error = tensor - (tensorU dot (diagEmbed(tensorS) dot tensorV.transpose(-2,-1)))
|
||||||
|
assertTrue(error.abs().sum().value() < TOLERANCE)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun testingBatchedSymEig(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
|
val tensor = randNormal(shape = intArrayOf(5,5), device = device)
|
||||||
|
val tensorSigma = tensor + tensor.transpose(-2,-1)
|
||||||
|
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||||
|
val error = tensorSigma - (tensorV dot (diagEmbed(tensorS) dot tensorV.transpose(-2,-1)))
|
||||||
|
assertTrue(error.abs().sum().value() < TOLERANCE)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
internal class TestTorchTensorAlgebra {
|
internal class TestTorchTensorAlgebra {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -116,4 +135,10 @@ internal class TestTorchTensorAlgebra {
|
|||||||
@Test
|
@Test
|
||||||
fun testTensorTransformations() = testingTensorTransformations()
|
fun testTensorTransformations() = testingTensorTransformations()
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedSVD() = testingBatchedSVD()
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedSymEig() = testingBatchedSymEig()
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user