diff --git a/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestTorchTensor.kt b/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestTorchTensor.kt
index a35d154a9..f78c788b8 100644
--- a/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestTorchTensor.kt
+++ b/kmath-torch/src/commonTest/kotlin/kscience.kmath.torch/TestTorchTensor.kt
@@ -12,9 +12,9 @@ internal inline fun <T, TorchTensorType : TorchTensor<T>,
     val shape = intArrayOf(2, 3, 4)
     val tensor = copyFromArray(array, shape = shape, device = device)
     val copyOfTensor = tensor.copy()
-    tensor[intArrayOf(0, 0)] = 0.1f
+    tensor[intArrayOf(1, 2, 3)] = 0.1f
     assertTrue(copyOfTensor.copyToArray() contentEquals array)
-    assertEquals(0.1f, tensor[intArrayOf(0, 0)])
+    assertEquals(0.1f, tensor[intArrayOf(1, 2, 3)])
     if(device != Device.CPU){
         val normalCpu = randNormal(intArrayOf(2, 3))
         val normalGpu = normalCpu.copyToDevice(device)
diff --git a/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/TorchTensorAlgebraJVM.kt b/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/TorchTensorAlgebraJVM.kt
new file mode 100644
index 000000000..ef9517ed1
--- /dev/null
+++ b/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/TorchTensorAlgebraJVM.kt
@@ -0,0 +1,386 @@
+package kscience.kmath.torch
+
+import kscience.kmath.memory.DeferScope
+import kscience.kmath.memory.withDeferScope
+
+public sealed class TorchTensorAlgebraJVM<
+        T,
+        PrimitiveArrayType,
+        TorchTensorType : TorchTensorJVM<T>> constructor(
+    internal val scope: DeferScope
+) : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType> {
+    override fun getNumThreads(): Int {
+        return JTorch.getNumThreads()
+    }
+
+    override fun setNumThreads(numThreads: Int): Unit {
+        JTorch.setNumThreads(numThreads)
+    }
+
+    override fun cudaAvailable(): Boolean {
+        return JTorch.cudaIsAvailable()
+    }
+
+    override fun setSeed(seed: Int): Unit {
+        JTorch.setSeed(seed)
+    }
+
+    override var checks: Boolean = false
+
+    internal abstract fun wrap(tensorHandle: Long): TorchTensorType
+
+    override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
+        if (checks) checkLinearOperation(this, other)
+        return wrap(JTorch.timesTensor(this.tensorHandle, other.tensorHandle))
+    }
+
+    override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
+        if (checks) checkLinearOperation(this, other)
+        JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle)
+    }
+
+    override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
+        if (checks) checkLinearOperation(this, other)
+        return wrap(JTorch.plusTensor(this.tensorHandle, other.tensorHandle))
+    }
+
+    override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
+        if (checks) checkLinearOperation(this, other)
+        JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle)
+    }
+
+    override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
+        if (checks) checkLinearOperation(this, other)
+        return wrap(JTorch.minusTensor(this.tensorHandle, other.tensorHandle))
+    }
+
+    override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
+        if (checks) checkLinearOperation(this, other)
+        JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle)
+    }
+
+    override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
+        wrap(JTorch.unaryMinus(this.tensorHandle))
+
+    override infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType {
+        if (checks) checkDotOperation(this, other)
+        return wrap(JTorch.matmul(this.tensorHandle, other.tensorHandle))
+    }
+
+    override infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
+        if (checks) checkDotOperation(this, other)
+        JTorch.matmulAssign(this.tensorHandle, other.tensorHandle)
+    }
+
+    override infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
+        if (checks) checkDotOperation(this, other)
+        JTorch.matmulRightAssign(this.tensorHandle, other.tensorHandle)
+    }
+
+    override fun diagonalEmbedding(
+        diagonalEntries: TorchTensorType, offset: Int, dim1: Int, dim2: Int
+    ): TorchTensorType =
+        wrap(JTorch.diagEmbed(diagonalEntries.tensorHandle, offset, dim1, dim2))
+
+    override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType {
+        if (checks) checkTranspose(this.dimension, i, j)
+        return wrap(JTorch.transposeTensor(tensorHandle, i, j))
+    }
+
+    override fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
+        if (checks) checkTranspose(this.dimension, i, j)
+        JTorch.transposeTensorAssign(tensorHandle, i, j)
+    }
+
+    override fun TorchTensorType.view(shape: IntArray): TorchTensorType {
+        if (checks) checkView(this, shape)
+        return wrap(JTorch.viewTensor(this.tensorHandle, shape))
+    }
+
+    override fun TorchTensorType.abs(): TorchTensorType = wrap(JTorch.absTensor(tensorHandle))
+    override fun TorchTensorType.absAssign(): Unit = JTorch.absTensorAssign(tensorHandle)
+
+    override fun TorchTensorType.sum(): TorchTensorType = wrap(JTorch.sumTensor(tensorHandle))
+    override fun TorchTensorType.sumAssign(): Unit = JTorch.sumTensorAssign(tensorHandle)
+
+    override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
+        wrap(JTorch.randintLike(this.tensorHandle, low, high))
+
+    override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit =
+        JTorch.randintLikeAssign(this.tensorHandle, low, high)
+
+    override fun TorchTensorType.copy(): TorchTensorType =
+        wrap(JTorch.copyTensor(this.tensorHandle))
+
+    override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
+        wrap(JTorch.copyToDevice(this.tensorHandle, device.toInt()))
+
+    override infix fun TorchTensorType.swap(other: TorchTensorType): Unit =
+        JTorch.swapTensors(this.tensorHandle, other.tensorHandle)
+}
+
+public sealed class TorchTensorPartialDivisionAlgebraJVM<T, PrimitiveArrayType,
+        TorchTensorType : TorchTensorOverFieldJVM<T>>(scope: DeferScope) :
+    TorchTensorAlgebraJVM<T, PrimitiveArrayType, TorchTensorType>(scope),
+    TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType> {
+
+    override operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType {
+        if (checks) checkLinearOperation(this, other)
+        return wrap(JTorch.divTensor(this.tensorHandle, other.tensorHandle))
+    }
+
+    override operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
+        if (checks) checkLinearOperation(this, other)
+        JTorch.divTensorAssign(this.tensorHandle, other.tensorHandle)
+    }
+
+    override fun TorchTensorType.randUniform(): TorchTensorType =
+        wrap(JTorch.randLike(this.tensorHandle))
+
+    override fun TorchTensorType.randUniformAssign(): Unit =
+        JTorch.randLikeAssign(this.tensorHandle)
+
+    override fun TorchTensorType.randNormal(): TorchTensorType =
+        wrap(JTorch.randnLike(this.tensorHandle))
+
+    override fun TorchTensorType.randNormalAssign(): Unit =
+        JTorch.randnLikeAssign(this.tensorHandle)
+
+    override fun TorchTensorType.exp(): TorchTensorType = wrap(JTorch.expTensor(tensorHandle))
+    override fun TorchTensorType.expAssign(): Unit = JTorch.expTensorAssign(tensorHandle)
+    override fun TorchTensorType.log(): TorchTensorType = wrap(JTorch.logTensor(tensorHandle))
+    override fun TorchTensorType.logAssign(): Unit = JTorch.logTensorAssign(tensorHandle)
+
+    override fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
+        val U = JTorch.emptyTensor()
+        val V = JTorch.emptyTensor()
+        val S = JTorch.emptyTensor()
+        JTorch.svdTensor(this.tensorHandle, U, S, V)
+        return Triple(wrap(U), wrap(S), wrap(V))
+    }
+
+    override fun TorchTensorType.symEig(eigenvectors: Boolean): Pair<TorchTensorType, TorchTensorType> {
+        val V = JTorch.emptyTensor()
+        val S = JTorch.emptyTensor()
+        JTorch.symeigTensor(this.tensorHandle, S, V, eigenvectors)
+        return Pair(wrap(S), wrap(V))
+    }
+
+    override fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean): TorchTensorType {
+        if (checks) this.checkIsValue()
+        return wrap(JTorch.autogradTensor(this.tensorHandle, variable.tensorHandle, retainGraph))
+    }
+
+    override infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
+        if (checks) this.checkIsValue()
+        return wrap(JTorch.autohessTensor(this.tensorHandle, variable.tensorHandle))
+    }
+
+    override fun TorchTensorType.detachFromGraph(): TorchTensorType =
+        wrap(JTorch.detachFromGraph(this.tensorHandle))
+
+}
+
+public class TorchTensorRealAlgebra(scope: DeferScope) :
+    TorchTensorPartialDivisionAlgebraJVM<Double, DoubleArray, TorchTensorReal>(scope) {
+    override fun wrap(tensorHandle: Long): TorchTensorReal =
+        TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
+
+    override fun TorchTensorReal.copyToArray(): DoubleArray =
+        this.elements().map { it.second }.toList().toDoubleArray()
+
+    override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): TorchTensorReal =
+        wrap(JTorch.fromBlobDouble(array, shape, device.toInt()))
+
+    override fun randNormal(shape: IntArray, device: Device): TorchTensorReal =
+        wrap(JTorch.randnDouble(shape, device.toInt()))
+
+    override fun randUniform(shape: IntArray, device: Device): TorchTensorReal =
+        wrap(JTorch.randDouble(shape, device.toInt()))
+
+    override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorReal =
+        wrap(JTorch.randintDouble(low, high, shape, device.toInt()))
+
+    override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
+        wrap(JTorch.plusDouble(this, other.tensorHandle))
+
+    override fun TorchTensorReal.plus(value: Double): TorchTensorReal =
+        wrap(JTorch.plusDouble(value, this.tensorHandle))
+
+    override fun TorchTensorReal.plusAssign(value: Double): Unit =
+        JTorch.plusDoubleAssign(value, this.tensorHandle)
+
+    override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal =
+        wrap(JTorch.plusDouble(-this, other.tensorHandle))
+
+    override fun TorchTensorReal.minus(value: Double): TorchTensorReal =
+        wrap(JTorch.plusDouble(-value, this.tensorHandle))
+
+    override fun TorchTensorReal.minusAssign(value: Double): Unit =
+        JTorch.plusDoubleAssign(-value, this.tensorHandle)
+
+    override operator fun Double.times(other: TorchTensorReal): TorchTensorReal =
+        wrap(JTorch.timesDouble(this, other.tensorHandle))
+
+    override fun TorchTensorReal.times(value: Double): TorchTensorReal =
+        wrap(JTorch.timesDouble(value, this.tensorHandle))
+
+    override fun TorchTensorReal.timesAssign(value: Double): Unit =
+        JTorch.timesDoubleAssign(value, this.tensorHandle)
+
+    override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
+        wrap(JTorch.fullDouble(value, shape, device.toInt()))
+}
+
+public class TorchTensorFloatAlgebra(scope: DeferScope) :
+    TorchTensorPartialDivisionAlgebraJVM<Float, FloatArray, TorchTensorFloat>(scope) {
+    override fun wrap(tensorHandle: Long): TorchTensorFloat =
+        TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
+
+    override fun TorchTensorFloat.copyToArray(): FloatArray =
+        this.elements().map { it.second }.toList().toFloatArray()
+
+    override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): TorchTensorFloat =
+        wrap(JTorch.fromBlobFloat(array, shape, device.toInt()))
+
+    override fun randNormal(shape: IntArray, device: Device): TorchTensorFloat =
+        wrap(JTorch.randnFloat(shape, device.toInt()))
+
+    override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat =
+        wrap(JTorch.randFloat(shape, device.toInt()))
+
+    override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorFloat =
+        wrap(JTorch.randintFloat(low, high, shape, device.toInt()))
+
+    override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
+        wrap(JTorch.plusFloat(this, other.tensorHandle))
+
+    override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
+        wrap(JTorch.plusFloat(value, this.tensorHandle))
+
+    override fun TorchTensorFloat.plusAssign(value: Float): Unit =
+        JTorch.plusFloatAssign(value, this.tensorHandle)
+
+    override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
+        wrap(JTorch.plusFloat(-this, other.tensorHandle))
+
+    override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
+        wrap(JTorch.plusFloat(-value, this.tensorHandle))
+
+    override fun TorchTensorFloat.minusAssign(value: Float): Unit =
+        JTorch.plusFloatAssign(-value, this.tensorHandle)
+
+    override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
+        wrap(JTorch.timesFloat(this, other.tensorHandle))
+
+    override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
+        wrap(JTorch.timesFloat(value, this.tensorHandle))
+
+    override fun TorchTensorFloat.timesAssign(value: Float): Unit =
+        JTorch.timesFloatAssign(value, this.tensorHandle)
+
+    override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
+        wrap(JTorch.fullFloat(value, shape, device.toInt()))
+}
+
+public class TorchTensorLongAlgebra(scope: DeferScope) :
+    TorchTensorAlgebraJVM<Long, LongArray, TorchTensorLong>(scope) {
+    override fun wrap(tensorHandle: Long): TorchTensorLong =
+        TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
+
+    override fun TorchTensorLong.copyToArray(): LongArray =
+        this.elements().map { it.second }.toList().toLongArray()
+
+    override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): TorchTensorLong =
+        wrap(JTorch.fromBlobLong(array, shape, device.toInt()))
+
+    override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong =
+        wrap(JTorch.randintLong(low, high, shape, device.toInt()))
+
+    override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
+        wrap(JTorch.plusLong(this, other.tensorHandle))
+
+    override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
+        wrap(JTorch.plusLong(value, this.tensorHandle))
+
+    override fun TorchTensorLong.plusAssign(value: Long): Unit =
+        JTorch.plusLongAssign(value, this.tensorHandle)
+
+    override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
+        wrap(JTorch.plusLong(-this, other.tensorHandle))
+
+    override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
+        wrap(JTorch.plusLong(-value, this.tensorHandle))
+
+    override fun TorchTensorLong.minusAssign(value: Long): Unit =
+        JTorch.plusLongAssign(-value, this.tensorHandle)
+
+    override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
+        wrap(JTorch.timesLong(this, other.tensorHandle))
+
+    override fun TorchTensorLong.times(value: Long): TorchTensorLong =
+        wrap(JTorch.timesLong(value, this.tensorHandle))
+
+    override fun TorchTensorLong.timesAssign(value: Long): Unit =
+        JTorch.timesLongAssign(value, this.tensorHandle)
+
+    override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
+        wrap(JTorch.fullLong(value, shape, device.toInt()))
+}
+
+public class TorchTensorIntAlgebra(scope: DeferScope) :
+    TorchTensorAlgebraJVM<Int, IntArray, TorchTensorInt>(scope) {
+    override fun wrap(tensorHandle: Long): TorchTensorInt =
+        TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
+
+    override fun TorchTensorInt.copyToArray(): IntArray =
+        this.elements().map { it.second }.toList().toIntArray()
+
+    override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): TorchTensorInt =
+        wrap(JTorch.fromBlobInt(array, shape, device.toInt()))
+
+    override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorInt =
+        wrap(JTorch.randintInt(low, high, shape, device.toInt()))
+
+    override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
+        wrap(JTorch.plusInt(this, other.tensorHandle))
+
+    override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
+        wrap(JTorch.plusInt(value, this.tensorHandle))
+
+    override fun TorchTensorInt.plusAssign(value: Int): Unit =
+        JTorch.plusIntAssign(value, this.tensorHandle)
+
+    override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
+        wrap(JTorch.plusInt(-this, other.tensorHandle))
+
+    override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
+        wrap(JTorch.plusInt(-value, this.tensorHandle))
+
+    override fun TorchTensorInt.minusAssign(value: Int): Unit =
+        JTorch.plusIntAssign(-value, this.tensorHandle)
+
+    override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
+        wrap(JTorch.timesInt(this, other.tensorHandle))
+
+    override fun TorchTensorInt.times(value: Int): TorchTensorInt =
+        wrap(JTorch.timesInt(value, this.tensorHandle))
+
+    override fun TorchTensorInt.timesAssign(value: Int): Unit =
+        JTorch.timesIntAssign(value, this.tensorHandle)
+
+    override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
+        wrap(JTorch.fullInt(value, shape, device.toInt()))
+}
+
+public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
+    withDeferScope { TorchTensorRealAlgebra(this).block() }
+
+public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
+    withDeferScope { TorchTensorFloatAlgebra(this).block() }
+
+public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
+    withDeferScope { TorchTensorLongAlgebra(this).block() }
+
+public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
+    withDeferScope { TorchTensorIntAlgebra(this).block() }
\ No newline at end of file
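
The scoped builder functions above are the intended entry point to this API: each algebra instance is bound to a DeferScope, so every native tensor allocated inside the block is disposed when the block exits. A minimal usage sketch (assuming the JTorch native library is present and loaded; the example itself is not part of this diff):

    import kscience.kmath.torch.*

    fun main(): Unit = TorchTensorRealAlgebra {
        // Tensors allocated here are freed automatically when the block returns.
        val x = randNormal(shape = intArrayOf(2, 3), device = Device.CPU)
        val y = full(value = 1.0, shape = intArrayOf(2, 3), device = Device.CPU)
        val z = x + y * 2.0               // element-wise algebra operations
        println(z.copyToArray().sum())    // copy back to a plain DoubleArray
    }
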
diff --git a/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/TorchTensorJVM.kt b/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/TorchTensorJVM.kt
index 0064b33b3..cbd56d884 100644
--- a/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/TorchTensorJVM.kt
+++ b/kmath-torch/src/jvmMain/kotlin/kscience/kmath/torch/TorchTensorJVM.kt
@@ -1,4 +1,94 @@
 package kscience.kmath.torch
 
-public class TorchTensorJVM {
+import kscience.kmath.memory.DeferScope
+
+public sealed class TorchTensorJVM<T> constructor(
+    scope: DeferScope,
+    internal val tensorHandle: Long
+) : TorchTensor<T>, TorchTensorMemoryHolder(scope)
+{
+    override fun close(): Unit = JTorch.disposeTensor(tensorHandle)
+
+    override val dimension: Int get() = JTorch.getDim(tensorHandle)
+    override val shape: IntArray
+        get() = (1..dimension).map { JTorch.getShapeAt(tensorHandle, it - 1) }.toIntArray()
+    override val strides: IntArray
+        get() = (1..dimension).map { JTorch.getStrideAt(tensorHandle, it - 1) }.toIntArray()
+    override val size: Int get() = JTorch.getNumel(tensorHandle)
+    override val device: Device get() = Device.fromInt(JTorch.getDevice(tensorHandle))
+
+    override fun toString(): String = JTorch.tensorToString(tensorHandle)
+
+    public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
+        scope = scope,
+        tensorHandle = JTorch.copyToDouble(this.tensorHandle)
+    )
+
+    public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
+        scope = scope,
+        tensorHandle = JTorch.copyToFloat(this.tensorHandle)
+    )
+
+    public fun copyToLong(): TorchTensorLong = TorchTensorLong(
+        scope = scope,
+        tensorHandle = JTorch.copyToLong(this.tensorHandle)
+    )
+
+    public fun copyToInt(): TorchTensorInt = TorchTensorInt(
+        scope = scope,
+        tensorHandle = JTorch.copyToInt(this.tensorHandle)
+    )
+}
+
+public sealed class TorchTensorOverFieldJVM<T> constructor(
+    scope: DeferScope,
+    tensorHandle: Long
+) : TorchTensorJVM<T>(scope, tensorHandle), TorchTensorOverField<T> {
+    override var requiresGrad: Boolean
+        get() = JTorch.requiresGrad(tensorHandle)
+        set(value) = JTorch.setRequiresGrad(tensorHandle, value)
+}
+
+public class TorchTensorReal internal constructor(
+    scope: DeferScope,
+    tensorHandle: Long
+) : TorchTensorOverFieldJVM<Double>(scope, tensorHandle) {
+    override fun item(): Double = JTorch.getItemDouble(tensorHandle)
+    override fun get(index: IntArray): Double = JTorch.getDouble(tensorHandle, index)
+    override fun set(index: IntArray, value: Double) {
+        JTorch.setDouble(tensorHandle, index, value)
+    }
+}
+
+public class TorchTensorFloat internal constructor(
+    scope: DeferScope,
+    tensorHandle: Long
+) : TorchTensorOverFieldJVM<Float>(scope, tensorHandle) {
+    override fun item(): Float = JTorch.getItemFloat(tensorHandle)
+    override fun get(index: IntArray): Float = JTorch.getFloat(tensorHandle, index)
+    override fun set(index: IntArray, value: Float) {
+        JTorch.setFloat(tensorHandle, index, value)
+    }
+}
+
+public class TorchTensorLong internal constructor(
+    scope: DeferScope,
+    tensorHandle: Long
+) : TorchTensorOverFieldJVM<Long>(scope, tensorHandle) {
+    override fun item(): Long = JTorch.getItemLong(tensorHandle)
+    override fun get(index: IntArray): Long = JTorch.getLong(tensorHandle, index)
+    override fun set(index: IntArray, value: Long) {
+        JTorch.setLong(tensorHandle, index, value)
+    }
+}
+
+public class TorchTensorInt internal constructor(
+    scope: DeferScope,
+    tensorHandle: Long
+) : TorchTensorOverFieldJVM<Int>(scope, tensorHandle) {
+    override fun item(): Int = JTorch.getItemInt(tensorHandle)
+    override fun get(index: IntArray): Int = JTorch.getInt(tensorHandle, index)
+    override fun set(index: IntArray, value: Int) {
+        JTorch.setInt(tensorHandle, index, value)
+    }
 }
\ No newline at end of file
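
The tensor classes above expose indexed access and cross-dtype copies on top of raw JTorch handles. A short illustration of that surface (a sketch, assuming the same runtime setup as the tests below; the device parameter of copyFromArray is assumed to default to Device.CPU, as the jvmTest call sites suggest):

    TorchTensorFloatAlgebra {
        val t = copyFromArray(floatArrayOf(1f, 2f, 3f, 4f), shape = intArrayOf(2, 2))
        t[intArrayOf(0, 1)] = 5f           // routed through JTorch.setFloat
        val d = t.copyToDouble()           // dtype conversion yields a new TorchTensorReal
        println(d[intArrayOf(0, 1)])       // prints 5.0
    }
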
diff --git a/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/BenchmarkMatMul.kt b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/BenchmarkMatMul.kt
new file mode 100644
index 000000000..2fa881c8b
--- /dev/null
+++ b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/BenchmarkMatMul.kt
@@ -0,0 +1,26 @@
+package kscience.kmath.torch
+
+import kotlin.test.Test
+
+
+class BenchmarkMatMul {
+
+    @Test
+    fun benchmarkMatMulDouble() = TorchTensorRealAlgebra {
+        benchmarkMatMul(20, 10, 100000, "Real")
+        benchmarkMatMul(200, 10, 10000, "Real")
+        benchmarkMatMul(2000, 3, 20, "Real")
+    }
+
+    @Test
+    fun benchmarkMatMulFloat() = TorchTensorFloatAlgebra {
+        benchmarkMatMul(20, 10, 100000, "Float")
+        benchmarkMatMul(200, 10, 10000, "Float")
+        benchmarkMatMul(2000, 3, 20, "Float")
+        if (cudaAvailable()) {
+            benchmarkMatMul(20, 10, 100000, "Float", Device.CUDA(0))
+            benchmarkMatMul(200, 10, 10000, "Float", Device.CUDA(0))
+            benchmarkMatMul(2000, 10, 1000, "Float", Device.CUDA(0))
+        }
+    }
+}
\ No newline at end of file
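
The benchmarkMatMul helper these tests call lives in shared test sources that are not part of this diff. For orientation, a rough sketch of what such a helper can look like (the name benchmarkMatMulSketch, the timing scheme, and the default device are illustrative assumptions, not the project's actual implementation):

    internal fun TorchTensorFloatAlgebra.benchmarkMatMulSketch(
        scale: Int, numWarmUp: Int, numIter: Int, device: Device = Device.CPU
    ) {
        val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
        val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
        repeat(numWarmUp) { lhs dotAssign rhs }   // warm-up: in-place matmul
        val start = System.nanoTime()
        repeat(numIter) { lhs dotAssign rhs }
        val ms = (System.nanoTime() - start) / (numIter * 1e6)
        println("MatMul $scale x $scale on $device: $ms ms averaged over $numIter runs")
    }
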
diff --git a/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/BenchmarkRandomGenerators.kt b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/BenchmarkRandomGenerators.kt
new file mode 100644
index 000000000..d76a52edf
--- /dev/null
+++ b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/BenchmarkRandomGenerators.kt
@@ -0,0 +1,27 @@
+package kscience.kmath.torch
+
+import kotlin.test.Test
+
+
+class BenchmarkRandomGenerators {
+    @Test
+    fun benchmarkRand1() = TorchTensorFloatAlgebra{
+        benchmarkingRand1()
+    }
+
+    @Test
+    fun benchmarkRand3() = TorchTensorFloatAlgebra{
+        benchmarkingRand3()
+    }
+
+    @Test
+    fun benchmarkRand5() = TorchTensorFloatAlgebra{
+        benchmarkingRand5()
+    }
+
+    @Test
+    fun benchmarkRand7() = TorchTensorFloatAlgebra{
+        benchmarkingRand7()
+    }
+
+}
\ No newline at end of file
diff --git a/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestAutograd.kt b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestAutograd.kt
new file mode 100644
index 000000000..9ffc7115e
--- /dev/null
+++ b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestAutograd.kt
@@ -0,0 +1,24 @@
+package kscience.kmath.torch
+
+import kotlin.test.Test
+
+
+class TestAutograd {
+    @Test
+    fun testAutoGrad() = TorchTensorFloatAlgebra {
+        withChecks {
+            withCuda { device ->
+                testingAutoGrad(device)
+            }
+        }
+    }
+
+    @Test
+    fun testBatchedAutoGrad() = TorchTensorFloatAlgebra {
+        withChecks {
+            withCuda { device ->
+                testingBatchedAutoGrad(device)
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt
new file mode 100644
index 000000000..528da3e40
--- /dev/null
+++ b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt
@@ -0,0 +1,39 @@
+package kscience.kmath.torch
+
+import kotlin.test.*
+
+
+class TestTorchTensor {
+
+    @Test
+    fun testCopying() = TorchTensorFloatAlgebra {
+        withCuda { device ->
+            testingCopying(device)
+        }
+    }
+
+    @Test
+    fun testRequiresGrad() = TorchTensorRealAlgebra {
+        testingRequiresGrad()
+    }
+
+    @Test
+    fun testTypeMoving() = TorchTensorFloatAlgebra {
+        val tensorInt = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3)).copyToInt()
+        TorchTensorIntAlgebra {
+            val temporalTensor = copyFromArray(intArrayOf(4, 5, 6), intArrayOf(3))
+            tensorInt swap temporalTensor
+            assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1, 2, 3))
+        }
+        assertTrue(tensorInt.copyToFloat().copyToArray() contentEquals floatArrayOf(4f, 5f, 6f))
+    }
+
+    @Test
+    fun testViewWithNoCopy() = TorchTensorIntAlgebra {
+        withChecks {
+            withCuda {
+                device -> testingViewWithNoCopy(device)
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt
new file mode 100644
index 000000000..810cf7399
--- /dev/null
+++ b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt
@@ -0,0 +1,63 @@
+package kscience.kmath.torch
+
+import kotlin.test.Test
+
+
+class TestTorchTensorAlgebra {
+
+    @Test
+    fun testScalarProduct() = TorchTensorRealAlgebra {
+        withChecks {
+            withCuda { device ->
+                testingScalarProduct(device)
+            }
+        }
+    }
+
+    @Test
+    fun testMatrixMultiplication() = TorchTensorRealAlgebra {
+        withChecks {
+            withCuda { device ->
+                testingMatrixMultiplication(device)
+            }
+        }
+    }
+
+    @Test
+    fun testLinearStructure() = TorchTensorRealAlgebra {
+        withChecks {
+            withCuda { device ->
+                testingLinearStructure(device)
+            }
+        }
+    }
+
+    @Test
+    fun testTensorTransformations() = TorchTensorRealAlgebra {
+        withChecks {
+            withCuda { device ->
+                testingTensorTransformations(device)
+            }
+        }
+    }
+
+    @Test
+    fun testBatchedSVD() = TorchTensorRealAlgebra {
+        withChecks {
+            withCuda { device ->
+                testingBatchedSVD(device)
+            }
+        }
+    }
+
+    @Test
+    fun testBatchedSymEig() = TorchTensorRealAlgebra {
+        withChecks {
+            withCuda { device ->
+                testingBatchedSymEig(device)
+            }
+        }
+    }
+
+
+}
\ No newline at end of file
diff --git a/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestUtils.kt b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestUtils.kt
index 2371b576e..0b5c45a2b 100644
--- a/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestUtils.kt
+++ b/kmath-torch/src/jvmTest/kotlin/kscience/kmath/torch/TestUtils.kt
@@ -6,10 +6,16 @@ class TestUtils {
     @Test
-    fun testJTorch() {
-        val tensor = JTorch.fullInt(54, intArrayOf(3), 0)
-        println(JTorch.tensorToString(tensor))
-        JTorch.disposeTensor(tensor)
+    fun testSetNumThreads() {
+        TorchTensorLongAlgebra {
+            testingSetNumThreads()
+        }
     }
 
+    @Test
+    fun testSeedSetting() = TorchTensorFloatAlgebra {
+        withCuda { device ->
+            testingSetSeed(device)
+        }
+    }
 }
\ No newline at end of file
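
The withChecks and withCuda combinators used throughout these tests are shared test utilities that this diff does not include. Plausible shapes for them, inferred purely from the call sites above (the names and bodies below are an illustrative sketch, not the actual helpers):

    internal inline fun <A : TorchTensorAlgebraJVM<*, *, *>> A.withChecksSketch(block: A.() -> Unit) {
        checks = true                       // enable shape/device validation
        try { block() } finally { checks = false }
    }

    internal inline fun <A : TorchTensorAlgebraJVM<*, *, *>> A.withCudaSketch(block: A.(Device) -> Unit) {
        block(Device.CPU)                   // always run on CPU
        if (cudaAvailable()) block(Device.CUDA(0))  // repeat on GPU when present
    }
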
diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebraNative.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebraNative.kt
index 93987fcf0..457da024e 100644
--- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebraNative.kt
+++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebraNative.kt
@@ -31,7 +31,6 @@ public sealed class TorchTensorAlgebraNative<
         set_seed(seed)
     }
 
-    override var checks: Boolean = false
 
     internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
 
@@ -108,21 +107,16 @@ public sealed class TorchTensorAlgebraNative<
     }
 
     override fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
-    override fun TorchTensorType.absAssign(): Unit {
-        abs_tensor_assign(tensorHandle)
-    }
+    override fun TorchTensorType.absAssign(): Unit = abs_tensor_assign(tensorHandle)
 
     override fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
-    override fun TorchTensorType.sumAssign(): Unit {
-        sum_tensor_assign(tensorHandle)
-    }
+    override fun TorchTensorType.sumAssign(): Unit = sum_tensor_assign(tensorHandle)
 
     override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
        wrap(randint_like(this.tensorHandle, low, high)!!)
 
-    override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit {
+    override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit =
         randint_like_assign(this.tensorHandle, low, high)
-    }
 
     override fun TorchTensorType.copy(): TorchTensorType =
         wrap(copy_tensor(this.tensorHandle)!!)
@@ -130,9 +124,8 @@
     override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
         wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
 
-    override infix fun TorchTensorType.swap(other: TorchTensorType): Unit {
+    override infix fun TorchTensorType.swap(other: TorchTensorType): Unit =
         swap_tensors(this.tensorHandle, other.tensorHandle)
-    }
 }
 
 public sealed class TorchTensorPartialDivisionAlgebraNative<
     {
         val U = empty_tensor()!!
@@ -200,7 +188,7 @@ public sealed class TorchTensorPartialDivisionAlgebraNative<
     constructor(
         scope = scope,
         tensorHandle = copy_to_int(this.tensorHandle)!!
     )
-
 }
 
 public sealed class TorchTensorOverFieldNative constructor(
diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/BenchmarkMatMul.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/BenchmarkMatMul.kt
index c3ba10dc0..0457413d9 100644
--- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/BenchmarkMatMul.kt
+++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/BenchmarkMatMul.kt
@@ -2,6 +2,7 @@ package kscience.kmath.torch
 
 import kotlin.test.Test
 
+
 internal class BenchmarkMatMul {
 
     @Test
diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt
index 3939d5e93..5750b70ff 100644
--- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt
+++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestAutograd.kt
@@ -1,6 +1,6 @@
 package kscience.kmath.torch
 
-import kotlin.test.*
+import kotlin.test.Test
 
 internal class TestAutograd {
diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt
index d05de3843..c5f4d5d34 100644
--- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt
+++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt
@@ -1,6 +1,6 @@
 package kscience.kmath.torch
 
-import kotlin.test.*
+import kotlin.test.Test
 
 internal class TestTorchTensorAlgebra {
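
Both backends now route seeding through a single entry point (setSeed on the JVM side, set_seed natively), which is what the updated TestUtils exercises. A reproducibility sketch against the JVM API (assuming the generator state is fully determined by the seed, as the testingSetSeed test implies):

    TorchTensorFloatAlgebra {
        setSeed(987654)
        val a = randNormal(shape = intArrayOf(3), device = Device.CPU)
        setSeed(987654)
        val b = randNormal(shape = intArrayOf(3), device = Device.CPU)
        // Identical seeds should reproduce the draw element-for-element.
        println(a.copyToArray() contentEquals b.copyToArray())   // expected: true
    }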