diff --git a/kmath-torch/ctorch/include/ctorch.h b/kmath-torch/ctorch/include/ctorch.h
index e5a8562fe..44a7ad22e 100644
--- a/kmath-torch/ctorch/include/ctorch.h
+++ b/kmath-torch/ctorch/include/ctorch.h
@@ -18,6 +18,8 @@ extern "C"
 
     void set_seed(int seed);
 
+    TorchTensorHandle empty_tensor();
+
     double *get_data_double(TorchTensorHandle tensor_handle);
     float *get_data_float(TorchTensorHandle tensor_handle);
     long *get_data_long(TorchTensorHandle tensor_handle);
@@ -115,9 +117,21 @@ extern "C"
 
     bool requires_grad(TorchTensorHandle tensor_handle);
     void requires_grad_(TorchTensorHandle tensor_handle, bool status);
-    void detach_from_graph(TorchTensorHandle tensor_handle);
+    TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
     TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
 
+    TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2);
+
+    void svd_tensor(TorchTensorHandle tensor_handle,
+                    TorchTensorHandle U_handle,
+                    TorchTensorHandle S_handle,
+                    TorchTensorHandle V_handle);
+
+    void symeig_tensor(TorchTensorHandle tensor_handle,
+                       TorchTensorHandle S_handle,
+                       TorchTensorHandle V_handle,
+                       bool eigenvectors);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/kmath-torch/ctorch/src/ctorch.cc b/kmath-torch/ctorch/src/ctorch.cc
index 7148f9e15..24dff472d 100644
--- a/kmath-torch/ctorch/src/ctorch.cc
+++ b/kmath-torch/ctorch/src/ctorch.cc
@@ -25,6 +25,11 @@ void set_seed(int seed)
     torch::manual_seed(seed);
 }
 
+TorchTensorHandle empty_tensor()
+{
+    return new torch::Tensor;
+}
+
 double *get_data_double(TorchTensorHandle tensor_handle)
 {
     return ctorch::cast(tensor_handle).data_ptr();
@@ -359,11 +364,37 @@ void requires_grad_(TorchTensorHandle tensor_handle, bool status)
 {
     ctorch::cast(tensor_handle).requires_grad_(status);
 }
-void detach_from_graph(TorchTensorHandle tensor_handle)
+TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
 {
-    ctorch::cast(tensor_handle).detach();
+    return new torch::Tensor(ctorch::cast(tensor_handle).detach());
 }
 TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
 {
     return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
 }
+
+TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2)
+{
+    return new torch::Tensor(torch::diag_embed(ctorch::cast(diags_handle), offset, dim1, dim2));
+}
+
+void svd_tensor(TorchTensorHandle tensor_handle,
+                TorchTensorHandle U_handle,
+                TorchTensorHandle S_handle,
+                TorchTensorHandle V_handle)
+{
+    auto [U, S, V] = torch::svd(ctorch::cast(tensor_handle));
+    ctorch::cast(U_handle) = U;
+    ctorch::cast(S_handle) = S;
+    ctorch::cast(V_handle) = V;
+}
+
+void symeig_tensor(TorchTensorHandle tensor_handle,
+                   TorchTensorHandle S_handle,
+                   TorchTensorHandle V_handle,
+                   bool eigenvectors)
+{
+    auto [S, V] = torch::symeig(ctorch::cast(tensor_handle), eigenvectors);
+    ctorch::cast(S_handle) = S;
+    ctorch::cast(V_handle) = V;
+}
diff --git a/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestAutogradGPU.kt b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestAutogradGPU.kt
index 18f13099d..7c974fced 100644
--- a/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestAutogradGPU.kt
+++ b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestAutogradGPU.kt
@@ -6,4 +6,8 @@ import kotlin.test.*
 internal class TestAutogradGPU {
     @Test
     fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
+
+    @Test
+    fun testBatchedAutoGrad() = testingBatchedAutoGrad(
+        bath = intArrayOf(2), dim=3, device = TorchDevice.TorchCUDA(0))
 }
\ No newline at end of file
diff --git a/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt
index dcdc4d7df..fdd5758c2 100644
--- a/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt
+++ b/kmath-torch/src/nativeGPUTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebraGPU.kt
@@ -21,4 +21,13 @@ class TestTorchTensorAlgebraGPU {
     fun testTensorTransformations() =
         testingTensorTransformations(device = TorchDevice.TorchCUDA(0))
 
+    @Test
+    fun testBatchedSVD() =
+        testingBatchedSVD(device = TorchDevice.TorchCUDA(0))
+
+    @Test
+    fun testBatchedSymEig() =
+        testingBatchedSymEig(device = TorchDevice.TorchCUDA(0))
+
+
 }
diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt
index 953368b4a..01025d55f 100644
--- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt
+++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensor.kt
@@ -53,9 +53,6 @@ public sealed class TorchTensor<T> constructor(
         get() = requires_grad(tensorHandle)
         set(value) = requires_grad_(tensorHandle, value)
 
-    public fun detachFromGraph(): Unit {
-        detach_from_graph(tensorHandle)
-    }
 }
 
 public class TorchTensorReal internal constructor(
diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
index 8f0f2c5e9..ab1af8769 100644
--- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
+++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
@@ -84,6 +84,7 @@ public sealed class TorchTensorAlgebra<
     public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
         wrap(transpose_tensor(tensorHandle, i, j)!!)
 
+
     public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
         transpose_tensor_assign(tensorHandle, i, j)
     }
@@ -93,13 +94,37 @@
         sum_tensor_assign(tensorHandle)
     }
 
+    public fun diagEmbed(diagonalEntries: TorchTensorType,
+                         offset: Int = 0, dim1: Int = -2, dim2: Int = -1): TorchTensorType =
+        wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
+
+    public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
+        val U = empty_tensor()!!
+        val V = empty_tensor()!!
+        val S = empty_tensor()!!
+        svd_tensor(this.tensorHandle, U, S, V)
+        return Triple(wrap(U), wrap(S), wrap(V))
+    }
+
+    public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType> {
+        val V = empty_tensor()!!
+        val S = empty_tensor()!!
+        symeig_tensor(this.tensorHandle, S, V, eigenvectors)
+        return Pair(wrap(S), wrap(V))
+    }
+
     public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
         wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
 
     public fun TorchTensorType.copy(): TorchTensorType =
         wrap(tensorHandle = copy_tensor(this.tensorHandle)!!)
+
     public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
         wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!)
+
+    public fun TorchTensorType.detachFromGraph(): TorchTensorType =
+        wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
+
 }
 
 public sealed class TorchTensorFieldAlgebra