Batched operations in algebra

This commit is contained in:
rgrit91 2021-01-13 15:15:14 +00:00
parent ca3cca65ef
commit 80f28dbcd5
9 changed files with 146 additions and 9 deletions

View File

@ -18,6 +18,8 @@ extern "C"
void set_seed(int seed); void set_seed(int seed);
TorchTensorHandle empty_tensor();
double *get_data_double(TorchTensorHandle tensor_handle); double *get_data_double(TorchTensorHandle tensor_handle);
float *get_data_float(TorchTensorHandle tensor_handle); float *get_data_float(TorchTensorHandle tensor_handle);
long *get_data_long(TorchTensorHandle tensor_handle); long *get_data_long(TorchTensorHandle tensor_handle);
@ -115,9 +117,21 @@ extern "C"
bool requires_grad(TorchTensorHandle tensor_handle); bool requires_grad(TorchTensorHandle tensor_handle);
void requires_grad_(TorchTensorHandle tensor_handle, bool status); void requires_grad_(TorchTensorHandle tensor_handle, bool status);
void detach_from_graph(TorchTensorHandle tensor_handle); TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable); TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2);
void svd_tensor(TorchTensorHandle tensor_handle,
TorchTensorHandle U_handle,
TorchTensorHandle S_handle,
TorchTensorHandle V_handle);
void symeig_tensor(TorchTensorHandle tensor_handle,
TorchTensorHandle S_handle,
TorchTensorHandle V_handle,
bool eigenvectors);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -25,6 +25,11 @@ void set_seed(int seed)
torch::manual_seed(seed); torch::manual_seed(seed);
} }
TorchTensorHandle empty_tensor()
{
return new torch::Tensor;
}
double *get_data_double(TorchTensorHandle tensor_handle) double *get_data_double(TorchTensorHandle tensor_handle)
{ {
return ctorch::cast(tensor_handle).data_ptr<double>(); return ctorch::cast(tensor_handle).data_ptr<double>();
@ -359,11 +364,37 @@ void requires_grad_(TorchTensorHandle tensor_handle, bool status)
{ {
ctorch::cast(tensor_handle).requires_grad_(status); ctorch::cast(tensor_handle).requires_grad_(status);
} }
void detach_from_graph(TorchTensorHandle tensor_handle) TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
{ {
ctorch::cast(tensor_handle).detach(); return new torch::Tensor(ctorch::cast(tensor_handle).detach());
} }
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable) TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
{ {
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]); return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
} }
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2)
{
return new torch::Tensor(torch::diag_embed(ctorch::cast(diags_handle), offset, dim1, dim2));
}
void svd_tensor(TorchTensorHandle tensor_handle,
TorchTensorHandle U_handle,
TorchTensorHandle S_handle,
TorchTensorHandle V_handle)
{
auto [U, S, V] = torch::svd(ctorch::cast(tensor_handle));
ctorch::cast(U_handle) = U;
ctorch::cast(S_handle) = S;
ctorch::cast(V_handle) = V;
}
void symeig_tensor(TorchTensorHandle tensor_handle,
TorchTensorHandle S_handle,
TorchTensorHandle V_handle,
bool eigenvectors)
{
auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
ctorch::cast(S_handle) = S;
ctorch::cast(V_handle) = V;
}

View File

@ -6,4 +6,8 @@ import kotlin.test.*
internal class TestAutogradGPU { internal class TestAutogradGPU {
@Test @Test
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0)) fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
@Test
fun testBatchedAutoGrad() = testingBatchedAutoGrad(
bath = intArrayOf(2), dim=3, device = TorchDevice.TorchCUDA(0))
} }

View File

@ -21,4 +21,13 @@ class TestTorchTensorAlgebraGPU {
fun testTensorTransformations() = fun testTensorTransformations() =
testingTensorTransformations(device = TorchDevice.TorchCUDA(0)) testingTensorTransformations(device = TorchDevice.TorchCUDA(0))
@Test
fun testBatchedSVD() =
testingBatchedSVD(device = TorchDevice.TorchCUDA(0))
@Test
fun testBatchedSymEig() =
testingBatchedSymEig(device = TorchDevice.TorchCUDA(0))
} }

View File

@ -53,9 +53,6 @@ public sealed class TorchTensor<T> constructor(
get() = requires_grad(tensorHandle) get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value) set(value) = requires_grad_(tensorHandle, value)
public fun detachFromGraph(): Unit {
detach_from_graph(tensorHandle)
}
} }
public class TorchTensorReal internal constructor( public class TorchTensorReal internal constructor(

View File

@ -84,6 +84,7 @@ public sealed class TorchTensorAlgebra<
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType = public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
wrap(transpose_tensor(tensorHandle, i, j)!!) wrap(transpose_tensor(tensorHandle, i, j)!!)
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit { public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
transpose_tensor_assign(tensorHandle, i, j) transpose_tensor_assign(tensorHandle, i, j)
} }
@ -93,13 +94,37 @@ public sealed class TorchTensorAlgebra<
sum_tensor_assign(tensorHandle) sum_tensor_assign(tensorHandle)
} }
public fun diagEmbed(diagonalEntries: TorchTensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1): TorchTensorType =
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
public fun TorchTensorType.svd(): Triple<TorchTensorType,TorchTensorType,TorchTensorType> {
val U = empty_tensor()!!
val V = empty_tensor()!!
val S = empty_tensor()!!
svd_tensor(this.tensorHandle, U, S, V)
return Triple(wrap(U), wrap(S), wrap(V))
}
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType> {
val V = empty_tensor()!!
val S = empty_tensor()!!
symeig_tensor(this.tensorHandle, S, V, eigenvectors)
return Pair(wrap(S), wrap(V))
}
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType = public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!) wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
public fun TorchTensorType.copy(): TorchTensorType = public fun TorchTensorType.copy(): TorchTensorType =
wrap(tensorHandle = copy_tensor(this.tensorHandle)!!) wrap(tensorHandle = copy_tensor(this.tensorHandle)!!)
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType = public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!) wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!)
public fun TorchTensorType.detachFromGraph(): TorchTensorType =
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
} }
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar, public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,

View File

@ -22,7 +22,36 @@ internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCP
} }
} }
internal fun testingBatchedAutoGrad(bath: IntArray,
dim: Int,
device: TorchDevice = TorchDevice.TorchCPU): Unit {
TorchTensorRealAlgebra {
setSeed(SEED)
val tensorX = randNormal(shape = bath+intArrayOf(1,dim), device = device)
tensorX.requiresGrad = true
val tensorXt = tensorX.transpose(-1,-2)
val randFeatures = randNormal(shape = bath+intArrayOf(dim, dim), device = device)
val tensorSigma = randFeatures + randFeatures.transpose(-2,-1)
val tensorMu = randNormal(shape = bath+intArrayOf(1,dim), device = device)
val expressionAtX =
0.5 * (tensorX dot (tensorSigma dot tensorXt)) + (tensorMu dot tensorXt) + 58.2
expressionAtX.sumAssign()
val gradientAtX = expressionAtX grad tensorX
val expectedGradientAtX = (tensorX dot tensorSigma) + tensorMu
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
assertTrue(error < TOLERANCE)
}
}
internal class TestAutograd { internal class TestAutograd {
@Test @Test
fun testAutoGrad() = testingAutoGrad(dim = 100) fun testAutoGrad() = testingAutoGrad(dim = 100)
@Test
fun testBatchedAutoGrad() = testingBatchedAutoGrad(bath = intArrayOf(2,10), dim=30)
} }

View File

@ -45,5 +45,8 @@ class TestTorchTensor {
assertTrue(tensor.requiresGrad) assertTrue(tensor.requiresGrad)
tensor.requiresGrad = false tensor.requiresGrad = false
assertTrue(!tensor.requiresGrad) assertTrue(!tensor.requiresGrad)
tensor.requiresGrad = true
val detachedTensor = tensor.detachFromGraph()
assertTrue(!detachedTensor.requiresGrad)
} }
} }

View File

@ -90,15 +90,34 @@ internal fun testingTensorTransformations(device: TorchDevice = TorchDevice.Torc
val tensor = randNormal(shape = intArrayOf(3, 3), device = device) val tensor = randNormal(shape = intArrayOf(3, 3), device = device)
val result = tensor.exp().log() val result = tensor.exp().log()
val assignResult = tensor.copy() val assignResult = tensor.copy()
assignResult.transposeAssign(0,1) assignResult.transposeAssign(0, 1)
assignResult.expAssign() assignResult.expAssign()
assignResult.logAssign() assignResult.logAssign()
assignResult.transposeAssign(0,1) assignResult.transposeAssign(0, 1)
val error = tensor - result val error = tensor - result
error.absAssign() error.absAssign()
error.sumAssign() error.sumAssign()
error += (tensor - assignResult).abs().sum() error += (tensor - assignResult).abs().sum()
assertTrue (error.value()< TOLERANCE) assertTrue(error.value() < TOLERANCE)
}
}
internal fun testingBatchedSVD(device: TorchDevice = TorchDevice.TorchCPU): Unit {
TorchTensorRealAlgebra {
val tensor = randNormal(shape = intArrayOf(7, 5, 3), device = device)
val (tensorU, tensorS, tensorV) = tensor.svd()
val error = tensor - (tensorU dot (diagEmbed(tensorS) dot tensorV.transpose(-2,-1)))
assertTrue(error.abs().sum().value() < TOLERANCE)
}
}
internal fun testingBatchedSymEig(device: TorchDevice = TorchDevice.TorchCPU): Unit {
TorchTensorRealAlgebra {
val tensor = randNormal(shape = intArrayOf(5,5), device = device)
val tensorSigma = tensor + tensor.transpose(-2,-1)
val (tensorS, tensorV) = tensorSigma.symEig()
val error = tensorSigma - (tensorV dot (diagEmbed(tensorS) dot tensorV.transpose(-2,-1)))
assertTrue(error.abs().sum().value() < TOLERANCE)
} }
} }
@ -116,4 +135,10 @@ internal class TestTorchTensorAlgebra {
@Test @Test
fun testTensorTransformations() = testingTensorTransformations() fun testTensorTransformations() = testingTensorTransformations()
@Test
fun testBatchedSVD() = testingBatchedSVD()
@Test
fun testBatchedSymEig() = testingBatchedSymEig()
} }