JVM implementation

This commit is contained in:
Roland Grinis 2021-01-20 16:32:47 +00:00
parent c9dfb6a08c
commit c141c04e99
14 changed files with 691 additions and 51 deletions

View File

@ -12,9 +12,9 @@ internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
val shape = intArrayOf(2, 3, 4)
val tensor = copyFromArray(array, shape = shape, device = device)
val copyOfTensor = tensor.copy()
tensor[intArrayOf(0, 0)] = 0.1f
tensor[intArrayOf(1, 2, 3)] = 0.1f
assertTrue(copyOfTensor.copyToArray() contentEquals array)
assertEquals(0.1f, tensor[intArrayOf(0, 0)])
assertEquals(0.1f, tensor[intArrayOf(1, 2, 3)])
if(device != Device.CPU){
val normalCpu = randNormal(intArrayOf(2, 3))
val normalGpu = normalCpu.copyToDevice(device)

View File

@ -0,0 +1,386 @@
package kscience.kmath.torch
import kscience.kmath.memory.DeferScope
import kscience.kmath.memory.withDeferScope
public sealed class TorchTensorAlgebraJVM<
T,
PrimitiveArrayType,
TorchTensorType : TorchTensorJVM<T>> constructor(
internal val scope: DeferScope
) : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType> {
override fun getNumThreads(): Int {
return JTorch.getNumThreads()
}
override fun setNumThreads(numThreads: Int): Unit {
JTorch.setNumThreads(numThreads)
}
override fun cudaAvailable(): Boolean {
return JTorch.cudaIsAvailable()
}
override fun setSeed(seed: Int): Unit {
JTorch.setSeed(seed)
}
override var checks: Boolean = false
internal abstract fun wrap(tensorHandle: Long): TorchTensorType
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, other)
return wrap(JTorch.timesTensor(this.tensorHandle, other.tensorHandle))
}
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
if (checks) checkLinearOperation(this, other)
JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle)
}
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, other)
return wrap(JTorch.plusTensor(this.tensorHandle, other.tensorHandle))
}
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
if (checks) checkLinearOperation(this, other)
JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle)
}
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, other)
return wrap(JTorch.minusTensor(this.tensorHandle, other.tensorHandle))
}
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
if (checks) checkLinearOperation(this, other)
JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle)
}
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
wrap(JTorch.unaryMinus(this.tensorHandle))
override infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType {
if (checks) checkDotOperation(this, other)
return wrap(JTorch.matmul(this.tensorHandle, other.tensorHandle))
}
override infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
if (checks) checkDotOperation(this, other)
JTorch.matmulAssign(this.tensorHandle, other.tensorHandle)
}
override infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
if (checks) checkDotOperation(this, other)
JTorch.matmulRightAssign(this.tensorHandle, other.tensorHandle)
}
override fun diagonalEmbedding(
diagonalEntries: TorchTensorType, offset: Int, dim1: Int, dim2: Int
): TorchTensorType =
wrap(JTorch.diagEmbed(diagonalEntries.tensorHandle, offset, dim1, dim2))
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType {
if (checks) checkTranspose(this.dimension, i, j)
return wrap(JTorch.transposeTensor(tensorHandle, i, j))
}
override fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
if (checks) checkTranspose(this.dimension, i, j)
JTorch.transposeTensorAssign(tensorHandle, i, j)
}
override fun TorchTensorType.view(shape: IntArray): TorchTensorType {
if (checks) checkView(this, shape)
return wrap(JTorch.viewTensor(this.tensorHandle, shape))
}
override fun TorchTensorType.abs(): TorchTensorType = wrap(JTorch.absTensor(tensorHandle))
override fun TorchTensorType.absAssign(): Unit = JTorch.absTensorAssign(tensorHandle)
override fun TorchTensorType.sum(): TorchTensorType = wrap(JTorch.sumTensor(tensorHandle))
override fun TorchTensorType.sumAssign(): Unit = JTorch.sumTensorAssign(tensorHandle)
override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
wrap(JTorch.randintLike(this.tensorHandle, low, high))
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit =
JTorch.randintLikeAssign(this.tensorHandle, low, high)
override fun TorchTensorType.copy(): TorchTensorType =
wrap(JTorch.copyTensor(this.tensorHandle))
override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
wrap(JTorch.copyToDevice(this.tensorHandle, device.toInt()))
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit =
JTorch.swapTensors(this.tensorHandle, other.tensorHandle)
}
public sealed class TorchTensorPartialDivisionAlgebraJVM<T, PrimitiveArrayType,
TorchTensorType : TorchTensorOverFieldJVM<T>>(scope: DeferScope) :
TorchTensorAlgebraJVM<T, PrimitiveArrayType, TorchTensorType>(scope),
TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType> {
override operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType {
if (checks) checkLinearOperation(this, other)
return wrap(JTorch.divTensor(this.tensorHandle, other.tensorHandle))
}
override operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
if (checks) checkLinearOperation(this, other)
JTorch.divTensorAssign(this.tensorHandle, other.tensorHandle)
}
override fun TorchTensorType.randUniform(): TorchTensorType =
wrap(JTorch.randLike(this.tensorHandle))
override fun TorchTensorType.randUniformAssign(): Unit =
JTorch.randLikeAssign(this.tensorHandle)
override fun TorchTensorType.randNormal(): TorchTensorType =
wrap(JTorch.randnLike(this.tensorHandle))
override fun TorchTensorType.randNormalAssign(): Unit =
JTorch.randnLikeAssign(this.tensorHandle)
override fun TorchTensorType.exp(): TorchTensorType = wrap(JTorch.expTensor(tensorHandle))
override fun TorchTensorType.expAssign(): Unit = JTorch.expTensorAssign(tensorHandle)
override fun TorchTensorType.log(): TorchTensorType = wrap(JTorch.logTensor(tensorHandle))
override fun TorchTensorType.logAssign(): Unit = JTorch.logTensorAssign(tensorHandle)
override fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
val U = JTorch.emptyTensor()
val V = JTorch.emptyTensor()
val S = JTorch.emptyTensor()
JTorch.svdTensor(this.tensorHandle, U, S, V)
return Triple(wrap(U), wrap(S), wrap(V))
}
override fun TorchTensorType.symEig(eigenvectors: Boolean): Pair<TorchTensorType, TorchTensorType> {
val V = JTorch.emptyTensor()
val S = JTorch.emptyTensor()
JTorch.symeigTensor(this.tensorHandle, S, V, eigenvectors)
return Pair(wrap(S), wrap(V))
}
override fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean): TorchTensorType {
if (checks) this.checkIsValue()
return wrap(JTorch.autogradTensor(this.tensorHandle, variable.tensorHandle, retainGraph))
}
override infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
if (checks) this.checkIsValue()
return wrap(JTorch.autohessTensor(this.tensorHandle, variable.tensorHandle))
}
override fun TorchTensorType.detachFromGraph(): TorchTensorType =
wrap(JTorch.detachFromGraph(this.tensorHandle))
}
public class TorchTensorRealAlgebra(scope: DeferScope) :
TorchTensorPartialDivisionAlgebraJVM<Double, DoubleArray, TorchTensorReal>(scope) {
override fun wrap(tensorHandle: Long): TorchTensorReal =
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
override fun TorchTensorReal.copyToArray(): DoubleArray =
this.elements().map { it.second }.toList().toDoubleArray()
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): TorchTensorReal =
wrap(JTorch.fromBlobDouble(array, shape, device.toInt()))
override fun randNormal(shape: IntArray, device: Device): TorchTensorReal =
wrap(JTorch.randnDouble(shape, device.toInt()))
override fun randUniform(shape: IntArray, device: Device): TorchTensorReal =
wrap(JTorch.randDouble(shape, device.toInt()))
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorReal =
wrap(JTorch.randintDouble(low, high, shape, device.toInt()))
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
wrap(JTorch.plusDouble(this, other.tensorHandle))
override fun TorchTensorReal.plus(value: Double): TorchTensorReal =
wrap(JTorch.plusDouble(value, this.tensorHandle))
override fun TorchTensorReal.plusAssign(value: Double): Unit =
JTorch.plusDoubleAssign(value, this.tensorHandle)
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal =
wrap(JTorch.plusDouble(-this, other.tensorHandle))
override fun TorchTensorReal.minus(value: Double): TorchTensorReal =
wrap(JTorch.plusDouble(-value, this.tensorHandle))
override fun TorchTensorReal.minusAssign(value: Double): Unit =
JTorch.plusDoubleAssign(-value, this.tensorHandle)
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal =
wrap(JTorch.timesDouble(this, other.tensorHandle))
override fun TorchTensorReal.times(value: Double): TorchTensorReal =
wrap(JTorch.timesDouble(value, this.tensorHandle))
override fun TorchTensorReal.timesAssign(value: Double): Unit =
JTorch.timesDoubleAssign(value, this.tensorHandle)
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
wrap(JTorch.fullDouble(value, shape, device.toInt()))
}
public class TorchTensorFloatAlgebra(scope: DeferScope) :
TorchTensorPartialDivisionAlgebraJVM<Float, FloatArray, TorchTensorFloat>(scope) {
override fun wrap(tensorHandle: Long): TorchTensorFloat =
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
override fun TorchTensorFloat.copyToArray(): FloatArray =
this.elements().map { it.second }.toList().toFloatArray()
override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): TorchTensorFloat =
wrap(JTorch.fromBlobFloat(array, shape, device.toInt()))
override fun randNormal(shape: IntArray, device: Device): TorchTensorFloat =
wrap(JTorch.randnFloat(shape, device.toInt()))
override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat =
wrap(JTorch.randFloat(shape, device.toInt()))
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorFloat =
wrap(JTorch.randintFloat(low, high, shape, device.toInt()))
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
wrap(JTorch.plusFloat(this, other.tensorHandle))
override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
wrap(JTorch.plusFloat(value, this.tensorHandle))
override fun TorchTensorFloat.plusAssign(value: Float): Unit =
JTorch.plusFloatAssign(value, this.tensorHandle)
override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
wrap(JTorch.plusFloat(-this, other.tensorHandle))
override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
wrap(JTorch.plusFloat(-value, this.tensorHandle))
override fun TorchTensorFloat.minusAssign(value: Float): Unit =
JTorch.plusFloatAssign(-value, this.tensorHandle)
override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
wrap(JTorch.timesFloat(this, other.tensorHandle))
override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
wrap(JTorch.timesFloat(value, this.tensorHandle))
override fun TorchTensorFloat.timesAssign(value: Float): Unit =
JTorch.timesFloatAssign(value, this.tensorHandle)
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
wrap(JTorch.fullFloat(value, shape, device.toInt()))
}
public class TorchTensorLongAlgebra(scope: DeferScope) :
TorchTensorAlgebraJVM<Long, LongArray, TorchTensorLong>(scope) {
override fun wrap(tensorHandle: Long): TorchTensorLong =
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
override fun TorchTensorLong.copyToArray(): LongArray =
this.elements().map { it.second }.toList().toLongArray()
override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): TorchTensorLong =
wrap(JTorch.fromBlobLong(array, shape, device.toInt()))
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong =
wrap(JTorch.randintLong(low, high, shape, device.toInt()))
override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
wrap(JTorch.plusLong(this, other.tensorHandle))
override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
wrap(JTorch.plusLong(value, this.tensorHandle))
override fun TorchTensorLong.plusAssign(value: Long): Unit =
JTorch.plusLongAssign(value, this.tensorHandle)
override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
wrap(JTorch.plusLong(-this, other.tensorHandle))
override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
wrap(JTorch.plusLong(-value, this.tensorHandle))
override fun TorchTensorLong.minusAssign(value: Long): Unit =
JTorch.plusLongAssign(-value, this.tensorHandle)
override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
wrap(JTorch.timesLong(this, other.tensorHandle))
override fun TorchTensorLong.times(value: Long): TorchTensorLong =
wrap(JTorch.timesLong(value, this.tensorHandle))
override fun TorchTensorLong.timesAssign(value: Long): Unit =
JTorch.timesLongAssign(value, this.tensorHandle)
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
wrap(JTorch.fullLong(value, shape, device.toInt()))
}
public class TorchTensorIntAlgebra(scope: DeferScope) :
TorchTensorAlgebraJVM<Int, IntArray, TorchTensorInt>(scope) {
override fun wrap(tensorHandle: Long): TorchTensorInt =
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
override fun TorchTensorInt.copyToArray(): IntArray =
this.elements().map { it.second }.toList().toIntArray()
override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): TorchTensorInt =
wrap(JTorch.fromBlobInt(array, shape, device.toInt()))
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorInt =
wrap(JTorch.randintInt(low, high, shape, device.toInt()))
override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
wrap(JTorch.plusInt(this, other.tensorHandle))
override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
wrap(JTorch.plusInt(value, this.tensorHandle))
override fun TorchTensorInt.plusAssign(value: Int): Unit =
JTorch.plusIntAssign(value, this.tensorHandle)
override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
wrap(JTorch.plusInt(-this, other.tensorHandle))
override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
wrap(JTorch.plusInt(-value, this.tensorHandle))
override fun TorchTensorInt.minusAssign(value: Int): Unit =
JTorch.plusIntAssign(-value, this.tensorHandle)
override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
wrap(JTorch.timesInt(this, other.tensorHandle))
override fun TorchTensorInt.times(value: Int): TorchTensorInt =
wrap(JTorch.timesInt(value, this.tensorHandle))
override fun TorchTensorInt.timesAssign(value: Int): Unit =
JTorch.timesIntAssign(value, this.tensorHandle)
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
wrap(JTorch.fullInt(value, shape, device.toInt()))
}
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
withDeferScope { TorchTensorRealAlgebra(this).block() }
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
withDeferScope { TorchTensorFloatAlgebra(this).block() }
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
withDeferScope { TorchTensorLongAlgebra(this).block() }
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
withDeferScope { TorchTensorIntAlgebra(this).block() }

View File

@ -1,4 +1,94 @@
package kscience.kmath.torch
public class TorchTensorJVM {
import kscience.kmath.memory.DeferScope
public sealed class TorchTensorJVM<T> constructor(
scope: DeferScope,
internal val tensorHandle: Long
) : TorchTensor<T>, TorchTensorMemoryHolder(scope)
{
override fun close(): Unit = JTorch.disposeTensor(tensorHandle)
override val dimension: Int get() = JTorch.getDim(tensorHandle)
override val shape: IntArray
get() = (1..dimension).map { JTorch.getShapeAt(tensorHandle, it - 1) }.toIntArray()
override val strides: IntArray
get() = (1..dimension).map { JTorch.getStrideAt(tensorHandle, it - 1) }.toIntArray()
override val size: Int get() = JTorch.getNumel(tensorHandle)
override val device: Device get() = Device.fromInt(JTorch.getDevice(tensorHandle))
override fun toString(): String = JTorch.tensorToString(tensorHandle)
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = JTorch.copyToDouble(this.tensorHandle)
)
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
scope = scope,
tensorHandle = JTorch.copyToFloat(this.tensorHandle)
)
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
scope = scope,
tensorHandle = JTorch.copyToLong(this.tensorHandle)
)
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
scope = scope,
tensorHandle = JTorch.copyToInt(this.tensorHandle)
)
}
public sealed class TorchTensorOverFieldJVM<T> constructor(
scope: DeferScope,
tensorHandle: Long
) : TorchTensorJVM<T>(scope, tensorHandle), TorchTensorOverField<T> {
override var requiresGrad: Boolean
get() = JTorch.requiresGrad(tensorHandle)
set(value) = JTorch.setRequiresGrad(tensorHandle, value)
}
public class TorchTensorReal internal constructor(
scope: DeferScope,
tensorHandle: Long
) : TorchTensorOverFieldJVM<Double>(scope, tensorHandle) {
override fun item(): Double = JTorch.getItemDouble(tensorHandle)
override fun get(index: IntArray): Double = JTorch.getDouble(tensorHandle, index)
override fun set(index: IntArray, value: Double) {
JTorch.setDouble(tensorHandle, index, value)
}
}
public class TorchTensorFloat internal constructor(
scope: DeferScope,
tensorHandle: Long
) : TorchTensorOverFieldJVM<Float>(scope, tensorHandle) {
override fun item(): Float = JTorch.getItemFloat(tensorHandle)
override fun get(index: IntArray): Float = JTorch.getFloat(tensorHandle, index)
override fun set(index: IntArray, value: Float) {
JTorch.setFloat(tensorHandle, index, value)
}
}
public class TorchTensorLong internal constructor(
scope: DeferScope,
tensorHandle: Long
) : TorchTensorOverFieldJVM<Long>(scope, tensorHandle) {
override fun item(): Long = JTorch.getItemLong(tensorHandle)
override fun get(index: IntArray): Long = JTorch.getLong(tensorHandle, index)
override fun set(index: IntArray, value: Long) {
JTorch.setLong(tensorHandle, index, value)
}
}
public class TorchTensorInt internal constructor(
scope: DeferScope,
tensorHandle: Long
) : TorchTensorOverFieldJVM<Int>(scope, tensorHandle) {
override fun item(): Int = JTorch.getItemInt(tensorHandle)
override fun get(index: IntArray): Int = JTorch.getInt(tensorHandle, index)
override fun set(index: IntArray, value: Int) {
JTorch.setInt(tensorHandle, index, value)
}
}

View File

@ -0,0 +1,26 @@
package kscience.kmath.torch
import kotlin.test.Test
class BenchmarkMatMul {
@Test
fun benchmarkMatMulDouble() = TorchTensorRealAlgebra {
benchmarkMatMul(20, 10, 100000, "Real")
benchmarkMatMul(200, 10, 10000, "Real")
benchmarkMatMul(2000, 3, 20, "Real")
}
@Test
fun benchmarkMatMulFloat() = TorchTensorFloatAlgebra {
benchmarkMatMul(20, 10, 100000, "Float")
benchmarkMatMul(200, 10, 10000, "Float")
benchmarkMatMul(2000, 3, 20, "Float")
if (cudaAvailable()) {
benchmarkMatMul(20, 10, 100000, "Float", Device.CUDA(0))
benchmarkMatMul(200, 10, 10000, "Float", Device.CUDA(0))
benchmarkMatMul(2000, 10, 1000, "Float", Device.CUDA(0))
}
}
}

View File

@ -0,0 +1,27 @@
package kscience.kmath.torch
import kotlin.test.Test
class BenchmarkRandomGenerators {
@Test
fun benchmarkRand1() = TorchTensorFloatAlgebra{
benchmarkingRand1()
}
@Test
fun benchmarkRand3() = TorchTensorFloatAlgebra{
benchmarkingRand3()
}
@Test
fun benchmarkRand5() = TorchTensorFloatAlgebra{
benchmarkingRand5()
}
@Test
fun benchmarkRand7() = TorchTensorFloatAlgebra{
benchmarkingRand7()
}
}

View File

@ -0,0 +1,24 @@
package kscience.kmath.torch
import kotlin.test.Test
class TestAutograd {
@Test
fun testAutoGrad() = TorchTensorFloatAlgebra {
withChecks {
withCuda { device ->
testingAutoGrad(device)
}
}
}
@Test
fun testBatchedAutoGrad() = TorchTensorFloatAlgebra {
withChecks {
withCuda { device ->
testingBatchedAutoGrad(device)
}
}
}
}

View File

@ -0,0 +1,39 @@
package kscience.kmath.torch
import kotlin.test.*
class TestTorchTensor {
@Test
fun testCopying() = TorchTensorFloatAlgebra {
withCuda { device ->
testingCopying(device)
}
}
@Test
fun testRequiresGrad() = TorchTensorRealAlgebra {
testingRequiresGrad()
}
@Test
fun testTypeMoving() = TorchTensorFloatAlgebra {
val tensorInt = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3)).copyToInt()
TorchTensorIntAlgebra {
val temporalTensor = copyFromArray(intArrayOf(4, 5, 6), intArrayOf(3))
tensorInt swap temporalTensor
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1, 2, 3))
}
assertTrue(tensorInt.copyToFloat().copyToArray() contentEquals floatArrayOf(4f, 5f, 6f))
}
@Test
fun testViewWithNoCopy() = TorchTensorIntAlgebra {
withChecks {
withCuda {
device -> testingViewWithNoCopy(device)
}
}
}
}

View File

@ -0,0 +1,63 @@
package kscience.kmath.torch
import kotlin.test.Test
class TestTorchTensorAlgebra {
@Test
fun testScalarProduct() = TorchTensorRealAlgebra {
withChecks {
withCuda { device ->
testingScalarProduct(device)
}
}
}
@Test
fun testMatrixMultiplication() = TorchTensorRealAlgebra {
withChecks {
withCuda { device ->
testingMatrixMultiplication(device)
}
}
}
@Test
fun testLinearStructure() = TorchTensorRealAlgebra {
withChecks {
withCuda { device ->
testingLinearStructure(device)
}
}
}
@Test
fun testTensorTransformations() = TorchTensorRealAlgebra {
withChecks {
withCuda { device ->
testingTensorTransformations(device)
}
}
}
@Test
fun testBatchedSVD() = TorchTensorRealAlgebra {
withChecks {
withCuda { device ->
testingBatchedSVD(device)
}
}
}
@Test
fun testBatchedSymEig() = TorchTensorRealAlgebra {
withChecks {
withCuda { device ->
testingBatchedSymEig(device)
}
}
}
}

View File

@ -6,10 +6,16 @@ import kotlin.test.*
class TestUtils {
@Test
fun testJTorch() {
val tensor = JTorch.fullInt(54, intArrayOf(3), 0)
println(JTorch.tensorToString(tensor))
JTorch.disposeTensor(tensor)
fun testSetNumThreads() {
TorchTensorLongAlgebra {
testingSetNumThreads()
}
}
@Test
fun testSeedSetting() = TorchTensorFloatAlgebra {
withCuda { device ->
testingSetSeed(device)
}
}
}

View File

@ -31,7 +31,6 @@ public sealed class TorchTensorAlgebraNative<
set_seed(seed)
}
override var checks: Boolean = false
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
@ -108,21 +107,16 @@ public sealed class TorchTensorAlgebraNative<
}
override fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
override fun TorchTensorType.absAssign(): Unit {
abs_tensor_assign(tensorHandle)
}
override fun TorchTensorType.absAssign(): Unit = abs_tensor_assign(tensorHandle)
override fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
override fun TorchTensorType.sumAssign(): Unit {
sum_tensor_assign(tensorHandle)
}
override fun TorchTensorType.sumAssign(): Unit = sum_tensor_assign(tensorHandle)
override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
wrap(randint_like(this.tensorHandle, low, high)!!)
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit {
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit =
randint_like_assign(this.tensorHandle, low, high)
}
override fun TorchTensorType.copy(): TorchTensorType =
wrap(copy_tensor(this.tensorHandle)!!)
@ -130,10 +124,9 @@ public sealed class TorchTensorAlgebraNative<
override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit {
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit =
swap_tensors(this.tensorHandle, other.tensorHandle)
}
}
public sealed class TorchTensorPartialDivisionAlgebraNative<T, TVar : CPrimitiveVar,
PrimitiveArrayType, TorchTensorType : TorchTensorOverFieldNative<T>>(scope: DeferScope) :
@ -153,26 +146,21 @@ public sealed class TorchTensorPartialDivisionAlgebraNative<T, TVar : CPrimitive
override fun TorchTensorType.randUniform(): TorchTensorType =
wrap(rand_like(this.tensorHandle)!!)
override fun TorchTensorType.randUniformAssign(): Unit {
override fun TorchTensorType.randUniformAssign(): Unit =
rand_like_assign(this.tensorHandle)
}
override fun TorchTensorType.randNormal(): TorchTensorType =
wrap(randn_like(this.tensorHandle)!!)
override fun TorchTensorType.randNormalAssign(): Unit {
override fun TorchTensorType.randNormalAssign(): Unit =
randn_like_assign(this.tensorHandle)
}
override fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
override fun TorchTensorType.expAssign(): Unit {
exp_tensor_assign(tensorHandle)
}
override fun TorchTensorType.expAssign(): Unit = exp_tensor_assign(tensorHandle)
override fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
override fun TorchTensorType.logAssign(): Unit {
log_tensor_assign(tensorHandle)
}
override fun TorchTensorType.logAssign(): Unit = log_tensor_assign(tensorHandle)
override fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
val U = empty_tensor()!!
@ -200,7 +188,7 @@ public sealed class TorchTensorPartialDivisionAlgebraNative<T, TVar : CPrimitive
}
override fun TorchTensorType.detachFromGraph(): TorchTensorType =
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
wrap(detach_from_graph(this.tensorHandle)!!)
}
@ -305,9 +293,8 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
wrap(plus_float(value, this.tensorHandle)!!)
override fun TorchTensorFloat.plusAssign(value: Float): Unit {
override fun TorchTensorFloat.plusAssign(value: Float): Unit =
plus_float_assign(value, this.tensorHandle)
}
override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
wrap(plus_float(-this, other.tensorHandle)!!)
@ -315,9 +302,8 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
wrap(plus_float(-value, this.tensorHandle)!!)
override fun TorchTensorFloat.minusAssign(value: Float): Unit {
override fun TorchTensorFloat.minusAssign(value: Float): Unit =
plus_float_assign(-value, this.tensorHandle)
}
override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
wrap(times_float(this, other.tensorHandle)!!)
@ -325,9 +311,8 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
wrap(times_float(value, this.tensorHandle)!!)
override fun TorchTensorFloat.timesAssign(value: Float): Unit {
override fun TorchTensorFloat.timesAssign(value: Float): Unit =
times_float_assign(value, this.tensorHandle)
}
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
@ -364,9 +349,8 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
wrap(plus_long(value, this.tensorHandle)!!)
override fun TorchTensorLong.plusAssign(value: Long): Unit {
override fun TorchTensorLong.plusAssign(value: Long): Unit =
plus_long_assign(value, this.tensorHandle)
}
override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
wrap(plus_long(-this, other.tensorHandle)!!)
@ -374,9 +358,8 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
wrap(plus_long(-value, this.tensorHandle)!!)
override fun TorchTensorLong.minusAssign(value: Long): Unit {
override fun TorchTensorLong.minusAssign(value: Long): Unit =
plus_long_assign(-value, this.tensorHandle)
}
override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
wrap(times_long(this, other.tensorHandle)!!)
@ -384,9 +367,8 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
override fun TorchTensorLong.times(value: Long): TorchTensorLong =
wrap(times_long(value, this.tensorHandle)!!)
override fun TorchTensorLong.timesAssign(value: Long): Unit {
override fun TorchTensorLong.timesAssign(value: Long): Unit =
times_long_assign(value, this.tensorHandle)
}
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
@ -422,9 +404,8 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
wrap(plus_int(value, this.tensorHandle)!!)
override fun TorchTensorInt.plusAssign(value: Int): Unit {
override fun TorchTensorInt.plusAssign(value: Int): Unit =
plus_int_assign(value, this.tensorHandle)
}
override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
wrap(plus_int(-this, other.tensorHandle)!!)
@ -432,9 +413,8 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
wrap(plus_int(-value, this.tensorHandle)!!)
override fun TorchTensorInt.minusAssign(value: Int): Unit {
override fun TorchTensorInt.minusAssign(value: Int): Unit =
plus_int_assign(-value, this.tensorHandle)
}
override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
wrap(times_int(this, other.tensorHandle)!!)
@ -442,9 +422,8 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
override fun TorchTensorInt.times(value: Int): TorchTensorInt =
wrap(times_int(value, this.tensorHandle)!!)
override fun TorchTensorInt.timesAssign(value: Int): Unit {
override fun TorchTensorInt.timesAssign(value: Int): Unit =
times_int_assign(value, this.tensorHandle)
}
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)

View File

@ -47,7 +47,6 @@ public sealed class TorchTensorNative<T> constructor(
scope = scope,
tensorHandle = copy_to_int(this.tensorHandle)!!
)
}
public sealed class TorchTensorOverFieldNative<T> constructor(

View File

@ -2,6 +2,7 @@ package kscience.kmath.torch
import kotlin.test.Test
internal class BenchmarkMatMul {
@Test

View File

@ -1,6 +1,6 @@
package kscience.kmath.torch
import kotlin.test.*
import kotlin.test.Test
internal class TestAutograd {

View File

@ -1,6 +1,6 @@
package kscience.kmath.torch
import kotlin.test.*
import kotlin.test.Test
internal class TestTorchTensorAlgebra {