forked from kscience/kmath
JVM implementation
This commit is contained in:
parent
c9dfb6a08c
commit
c141c04e99
@ -12,9 +12,9 @@ internal inline fun <TorchTensorType : TorchTensorOverField<Float>,
|
||||
val shape = intArrayOf(2, 3, 4)
|
||||
val tensor = copyFromArray(array, shape = shape, device = device)
|
||||
val copyOfTensor = tensor.copy()
|
||||
tensor[intArrayOf(0, 0)] = 0.1f
|
||||
tensor[intArrayOf(1, 2, 3)] = 0.1f
|
||||
assertTrue(copyOfTensor.copyToArray() contentEquals array)
|
||||
assertEquals(0.1f, tensor[intArrayOf(0, 0)])
|
||||
assertEquals(0.1f, tensor[intArrayOf(1, 2, 3)])
|
||||
if(device != Device.CPU){
|
||||
val normalCpu = randNormal(intArrayOf(2, 3))
|
||||
val normalGpu = normalCpu.copyToDevice(device)
|
||||
|
@ -0,0 +1,386 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kscience.kmath.memory.DeferScope
|
||||
import kscience.kmath.memory.withDeferScope
|
||||
|
||||
public sealed class TorchTensorAlgebraJVM<
|
||||
T,
|
||||
PrimitiveArrayType,
|
||||
TorchTensorType : TorchTensorJVM<T>> constructor(
|
||||
internal val scope: DeferScope
|
||||
) : TorchTensorAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||
override fun getNumThreads(): Int {
|
||||
return JTorch.getNumThreads()
|
||||
}
|
||||
|
||||
override fun setNumThreads(numThreads: Int): Unit {
|
||||
JTorch.setNumThreads(numThreads)
|
||||
}
|
||||
|
||||
override fun cudaAvailable(): Boolean {
|
||||
return JTorch.cudaIsAvailable()
|
||||
}
|
||||
|
||||
override fun setSeed(seed: Int): Unit {
|
||||
JTorch.setSeed(seed)
|
||||
}
|
||||
|
||||
override var checks: Boolean = false
|
||||
|
||||
internal abstract fun wrap(tensorHandle: Long): TorchTensorType
|
||||
|
||||
override operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(JTorch.timesTensor(this.tensorHandle, other.tensorHandle))
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
JTorch.timesTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(JTorch.plusTensor(this.tensorHandle, other.tensorHandle))
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
JTorch.plusTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(JTorch.minusTensor(this.tensorHandle, other.tensorHandle))
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
JTorch.minusTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
||||
wrap(JTorch.unaryMinus(this.tensorHandle))
|
||||
|
||||
override infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
return wrap(JTorch.matmul(this.tensorHandle, other.tensorHandle))
|
||||
}
|
||||
|
||||
override infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
JTorch.matmulAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkDotOperation(this, other)
|
||||
JTorch.matmulRightAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(
|
||||
diagonalEntries: TorchTensorType, offset: Int, dim1: Int, dim2: Int
|
||||
): TorchTensorType =
|
||||
wrap(JTorch.diagEmbed(diagonalEntries.tensorHandle, offset, dim1, dim2))
|
||||
|
||||
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType {
|
||||
if (checks) checkTranspose(this.dimension, i, j)
|
||||
return wrap(JTorch.transposeTensor(tensorHandle, i, j))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
|
||||
if (checks) checkTranspose(this.dimension, i, j)
|
||||
JTorch.transposeTensorAssign(tensorHandle, i, j)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.view(shape: IntArray): TorchTensorType {
|
||||
if (checks) checkView(this, shape)
|
||||
return wrap(JTorch.viewTensor(this.tensorHandle, shape))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.abs(): TorchTensorType = wrap(JTorch.absTensor(tensorHandle))
|
||||
override fun TorchTensorType.absAssign(): Unit = JTorch.absTensorAssign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.sum(): TorchTensorType = wrap(JTorch.sumTensor(tensorHandle))
|
||||
override fun TorchTensorType.sumAssign(): Unit = JTorch.sumTensorAssign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
|
||||
wrap(JTorch.randintLike(this.tensorHandle, low, high))
|
||||
|
||||
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit =
|
||||
JTorch.randintLikeAssign(this.tensorHandle, low, high)
|
||||
|
||||
override fun TorchTensorType.copy(): TorchTensorType =
|
||||
wrap(JTorch.copyTensor(this.tensorHandle))
|
||||
|
||||
override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
|
||||
wrap(JTorch.copyToDevice(this.tensorHandle, device.toInt()))
|
||||
|
||||
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit =
|
||||
JTorch.swapTensors(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
public sealed class TorchTensorPartialDivisionAlgebraJVM<T, PrimitiveArrayType,
|
||||
TorchTensorType : TorchTensorOverFieldJVM<T>>(scope: DeferScope) :
|
||||
TorchTensorAlgebraJVM<T, PrimitiveArrayType, TorchTensorType>(scope),
|
||||
TorchTensorPartialDivisionAlgebra<T, PrimitiveArrayType, TorchTensorType> {
|
||||
|
||||
override operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
return wrap(JTorch.divTensor(this.tensorHandle, other.tensorHandle))
|
||||
}
|
||||
|
||||
override operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
||||
if (checks) checkLinearOperation(this, other)
|
||||
JTorch.divTensorAssign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.randUniform(): TorchTensorType =
|
||||
wrap(JTorch.randLike(this.tensorHandle))
|
||||
|
||||
override fun TorchTensorType.randUniformAssign(): Unit =
|
||||
JTorch.randLikeAssign(this.tensorHandle)
|
||||
|
||||
override fun TorchTensorType.randNormal(): TorchTensorType =
|
||||
wrap(JTorch.randnLike(this.tensorHandle))
|
||||
|
||||
override fun TorchTensorType.randNormalAssign(): Unit =
|
||||
JTorch.randnLikeAssign(this.tensorHandle)
|
||||
|
||||
override fun TorchTensorType.exp(): TorchTensorType = wrap(JTorch.expTensor(tensorHandle))
|
||||
override fun TorchTensorType.expAssign(): Unit = JTorch.expTensorAssign(tensorHandle)
|
||||
override fun TorchTensorType.log(): TorchTensorType = wrap(JTorch.logTensor(tensorHandle))
|
||||
override fun TorchTensorType.logAssign(): Unit = JTorch.logTensorAssign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
|
||||
val U = JTorch.emptyTensor()
|
||||
val V = JTorch.emptyTensor()
|
||||
val S = JTorch.emptyTensor()
|
||||
JTorch.svdTensor(this.tensorHandle, U, S, V)
|
||||
return Triple(wrap(U), wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.symEig(eigenvectors: Boolean): Pair<TorchTensorType, TorchTensorType> {
|
||||
val V = JTorch.emptyTensor()
|
||||
val S = JTorch.emptyTensor()
|
||||
JTorch.symeigTensor(this.tensorHandle, S, V, eigenvectors)
|
||||
return Pair(wrap(S), wrap(V))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean): TorchTensorType {
|
||||
if (checks) this.checkIsValue()
|
||||
return wrap(JTorch.autogradTensor(this.tensorHandle, variable.tensorHandle, retainGraph))
|
||||
}
|
||||
|
||||
override infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
|
||||
if (checks) this.checkIsValue()
|
||||
return wrap(JTorch.autohessTensor(this.tensorHandle, variable.tensorHandle))
|
||||
}
|
||||
|
||||
override fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||
wrap(JTorch.detachFromGraph(this.tensorHandle))
|
||||
|
||||
}
|
||||
|
||||
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||
TorchTensorPartialDivisionAlgebraJVM<Double, DoubleArray, TorchTensorReal>(scope) {
|
||||
override fun wrap(tensorHandle: Long): TorchTensorReal =
|
||||
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
||||
this.elements().map { it.second }.toList().toDoubleArray()
|
||||
|
||||
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(JTorch.fromBlobDouble(array, shape, device.toInt()))
|
||||
|
||||
override fun randNormal(shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(JTorch.randnDouble(shape, device.toInt()))
|
||||
|
||||
override fun randUniform(shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(JTorch.randDouble(shape, device.toInt()))
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(JTorch.randintDouble(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(JTorch.plusDouble(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.plus(value: Double): TorchTensorReal =
|
||||
wrap(JTorch.plusDouble(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.plusAssign(value: Double): Unit =
|
||||
JTorch.plusDoubleAssign(value, this.tensorHandle)
|
||||
|
||||
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(JTorch.plusDouble(-this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.minus(value: Double): TorchTensorReal =
|
||||
wrap(JTorch.plusDouble(-value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.minusAssign(value: Double): Unit =
|
||||
JTorch.plusDoubleAssign(-value, this.tensorHandle)
|
||||
|
||||
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal =
|
||||
wrap(JTorch.timesDouble(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.times(value: Double): TorchTensorReal =
|
||||
wrap(JTorch.timesDouble(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorReal.timesAssign(value: Double): Unit =
|
||||
JTorch.timesDoubleAssign(value, this.tensorHandle)
|
||||
|
||||
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
|
||||
wrap(JTorch.fullDouble(value, shape, device.toInt()))
|
||||
}
|
||||
|
||||
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
TorchTensorPartialDivisionAlgebraJVM<Float, FloatArray, TorchTensorFloat>(scope) {
|
||||
override fun wrap(tensorHandle: Long): TorchTensorFloat =
|
||||
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorFloat.copyToArray(): FloatArray =
|
||||
this.elements().map { it.second }.toList().toFloatArray()
|
||||
|
||||
override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(JTorch.fromBlobFloat(array, shape, device.toInt()))
|
||||
|
||||
override fun randNormal(shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(JTorch.randnFloat(shape, device.toInt()))
|
||||
|
||||
override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(JTorch.randFloat(shape, device.toInt()))
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(JTorch.randintFloat(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(JTorch.plusFloat(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
|
||||
wrap(JTorch.plusFloat(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.plusAssign(value: Float): Unit =
|
||||
JTorch.plusFloatAssign(value, this.tensorHandle)
|
||||
|
||||
override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(JTorch.plusFloat(-this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
|
||||
wrap(JTorch.plusFloat(-value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.minusAssign(value: Float): Unit =
|
||||
JTorch.plusFloatAssign(-value, this.tensorHandle)
|
||||
|
||||
override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(JTorch.timesFloat(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
|
||||
wrap(JTorch.timesFloat(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorFloat.timesAssign(value: Float): Unit =
|
||||
JTorch.timesFloatAssign(value, this.tensorHandle)
|
||||
|
||||
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(JTorch.fullFloat(value, shape, device.toInt()))
|
||||
}
|
||||
|
||||
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
TorchTensorAlgebraJVM<Long, LongArray, TorchTensorLong>(scope) {
|
||||
override fun wrap(tensorHandle: Long): TorchTensorLong =
|
||||
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorLong.copyToArray(): LongArray =
|
||||
this.elements().map { it.second }.toList().toLongArray()
|
||||
|
||||
override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(JTorch.fromBlobLong(array, shape, device.toInt()))
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(JTorch.randintLong(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(JTorch.plusLong(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
|
||||
wrap(JTorch.plusLong(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.plusAssign(value: Long): Unit =
|
||||
JTorch.plusLongAssign(value, this.tensorHandle)
|
||||
|
||||
override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(JTorch.plusLong(-this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
|
||||
wrap(JTorch.plusLong(-value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.minusAssign(value: Long): Unit =
|
||||
JTorch.plusLongAssign(-value, this.tensorHandle)
|
||||
|
||||
override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(JTorch.timesLong(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.times(value: Long): TorchTensorLong =
|
||||
wrap(JTorch.timesLong(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorLong.timesAssign(value: Long): Unit =
|
||||
JTorch.timesLongAssign(value, this.tensorHandle)
|
||||
|
||||
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(JTorch.fullLong(value, shape, device.toInt()))
|
||||
}
|
||||
|
||||
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||
TorchTensorAlgebraJVM<Int, IntArray, TorchTensorInt>(scope) {
|
||||
override fun wrap(tensorHandle: Long): TorchTensorInt =
|
||||
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
override fun TorchTensorInt.copyToArray(): IntArray =
|
||||
this.elements().map { it.second }.toList().toIntArray()
|
||||
|
||||
override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(JTorch.fromBlobInt(array, shape, device.toInt()))
|
||||
|
||||
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(JTorch.randintInt(low, high, shape, device.toInt()))
|
||||
|
||||
override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(JTorch.plusInt(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
|
||||
wrap(JTorch.plusInt(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.plusAssign(value: Int): Unit =
|
||||
JTorch.plusIntAssign(value, this.tensorHandle)
|
||||
|
||||
override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(JTorch.plusInt(-this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
|
||||
wrap(JTorch.plusInt(-value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.minusAssign(value: Int): Unit =
|
||||
JTorch.plusIntAssign(-value, this.tensorHandle)
|
||||
|
||||
override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(JTorch.timesInt(this, other.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.times(value: Int): TorchTensorInt =
|
||||
wrap(JTorch.timesInt(value, this.tensorHandle))
|
||||
|
||||
override fun TorchTensorInt.timesAssign(value: Int): Unit =
|
||||
JTorch.timesIntAssign(value, this.tensorHandle)
|
||||
|
||||
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(JTorch.fullInt(value, shape, device.toInt()))
|
||||
}
|
||||
|
||||
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorRealAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorFloatAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorLongAlgebra(this).block() }
|
||||
|
||||
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||
withDeferScope { TorchTensorIntAlgebra(this).block() }
|
@ -1,4 +1,94 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
public class TorchTensorJVM {
|
||||
import kscience.kmath.memory.DeferScope
|
||||
|
||||
public sealed class TorchTensorJVM<T> constructor(
|
||||
scope: DeferScope,
|
||||
internal val tensorHandle: Long
|
||||
) : TorchTensor<T>, TorchTensorMemoryHolder(scope)
|
||||
{
|
||||
override fun close(): Unit = JTorch.disposeTensor(tensorHandle)
|
||||
|
||||
override val dimension: Int get() = JTorch.getDim(tensorHandle)
|
||||
override val shape: IntArray
|
||||
get() = (1..dimension).map { JTorch.getShapeAt(tensorHandle, it - 1) }.toIntArray()
|
||||
override val strides: IntArray
|
||||
get() = (1..dimension).map { JTorch.getStrideAt(tensorHandle, it - 1) }.toIntArray()
|
||||
override val size: Int get() = JTorch.getNumel(tensorHandle)
|
||||
override val device: Device get() = Device.fromInt(JTorch.getDevice(tensorHandle))
|
||||
|
||||
override fun toString(): String = JTorch.tensorToString(tensorHandle)
|
||||
|
||||
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = JTorch.copyToDouble(this.tensorHandle)
|
||||
)
|
||||
|
||||
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
|
||||
scope = scope,
|
||||
tensorHandle = JTorch.copyToFloat(this.tensorHandle)
|
||||
)
|
||||
|
||||
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
|
||||
scope = scope,
|
||||
tensorHandle = JTorch.copyToLong(this.tensorHandle)
|
||||
)
|
||||
|
||||
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
|
||||
scope = scope,
|
||||
tensorHandle = JTorch.copyToInt(this.tensorHandle)
|
||||
)
|
||||
}
|
||||
|
||||
public sealed class TorchTensorOverFieldJVM<T> constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: Long
|
||||
) : TorchTensorJVM<T>(scope, tensorHandle), TorchTensorOverField<T> {
|
||||
override var requiresGrad: Boolean
|
||||
get() = JTorch.requiresGrad(tensorHandle)
|
||||
set(value) = JTorch.setRequiresGrad(tensorHandle, value)
|
||||
}
|
||||
|
||||
public class TorchTensorReal internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: Long
|
||||
) : TorchTensorOverFieldJVM<Double>(scope, tensorHandle) {
|
||||
override fun item(): Double = JTorch.getItemDouble(tensorHandle)
|
||||
override fun get(index: IntArray): Double = JTorch.getDouble(tensorHandle, index)
|
||||
override fun set(index: IntArray, value: Double) {
|
||||
JTorch.setDouble(tensorHandle, index, value)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorFloat internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: Long
|
||||
) : TorchTensorOverFieldJVM<Float>(scope, tensorHandle) {
|
||||
override fun item(): Float = JTorch.getItemFloat(tensorHandle)
|
||||
override fun get(index: IntArray): Float = JTorch.getFloat(tensorHandle, index)
|
||||
override fun set(index: IntArray, value: Float) {
|
||||
JTorch.setFloat(tensorHandle, index, value)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorLong internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: Long
|
||||
) : TorchTensorOverFieldJVM<Long>(scope, tensorHandle) {
|
||||
override fun item(): Long = JTorch.getItemLong(tensorHandle)
|
||||
override fun get(index: IntArray): Long = JTorch.getLong(tensorHandle, index)
|
||||
override fun set(index: IntArray, value: Long) {
|
||||
JTorch.setLong(tensorHandle, index, value)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorInt internal constructor(
|
||||
scope: DeferScope,
|
||||
tensorHandle: Long
|
||||
) : TorchTensorOverFieldJVM<Int>(scope, tensorHandle) {
|
||||
override fun item(): Int = JTorch.getItemInt(tensorHandle)
|
||||
override fun get(index: IntArray): Int = JTorch.getInt(tensorHandle, index)
|
||||
override fun set(index: IntArray, value: Int) {
|
||||
JTorch.setInt(tensorHandle, index, value)
|
||||
}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
class BenchmarkMatMul {
|
||||
|
||||
@Test
|
||||
fun benchmarkMatMulDouble() = TorchTensorRealAlgebra {
|
||||
benchmarkMatMul(20, 10, 100000, "Real")
|
||||
benchmarkMatMul(200, 10, 10000, "Real")
|
||||
benchmarkMatMul(2000, 3, 20, "Real")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkMatMulFloat() = TorchTensorFloatAlgebra {
|
||||
benchmarkMatMul(20, 10, 100000, "Float")
|
||||
benchmarkMatMul(200, 10, 10000, "Float")
|
||||
benchmarkMatMul(2000, 3, 20, "Float")
|
||||
if (cudaAvailable()) {
|
||||
benchmarkMatMul(20, 10, 100000, "Float", Device.CUDA(0))
|
||||
benchmarkMatMul(200, 10, 10000, "Float", Device.CUDA(0))
|
||||
benchmarkMatMul(2000, 10, 1000, "Float", Device.CUDA(0))
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
class BenchmarkRandomGenerators {
|
||||
@Test
|
||||
fun benchmarkRand1() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand1()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkRand3() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand3()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkRand5() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand5()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun benchmarkRand7() = TorchTensorFloatAlgebra{
|
||||
benchmarkingRand7()
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
class TestAutograd {
|
||||
@Test
|
||||
fun testAutoGrad() = TorchTensorFloatAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingAutoGrad(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBatchedAutoGrad() = TorchTensorFloatAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingBatchedAutoGrad(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,39 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.*
|
||||
|
||||
|
||||
class TestTorchTensor {
|
||||
|
||||
@Test
|
||||
fun testCopying() = TorchTensorFloatAlgebra {
|
||||
withCuda { device ->
|
||||
testingCopying(device)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testRequiresGrad() = TorchTensorRealAlgebra {
|
||||
testingRequiresGrad()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTypeMoving() = TorchTensorFloatAlgebra {
|
||||
val tensorInt = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3)).copyToInt()
|
||||
TorchTensorIntAlgebra {
|
||||
val temporalTensor = copyFromArray(intArrayOf(4, 5, 6), intArrayOf(3))
|
||||
tensorInt swap temporalTensor
|
||||
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1, 2, 3))
|
||||
}
|
||||
assertTrue(tensorInt.copyToFloat().copyToArray() contentEquals floatArrayOf(4f, 5f, 6f))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testViewWithNoCopy() = TorchTensorIntAlgebra {
|
||||
withChecks {
|
||||
withCuda {
|
||||
device -> testingViewWithNoCopy(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,63 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
class TestTorchTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun testScalarProduct() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingScalarProduct(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMatrixMultiplication() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingMatrixMultiplication(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testLinearStructure() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingLinearStructure(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTensorTransformations() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingTensorTransformations(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBatchedSVD() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingBatchedSVD(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBatchedSymEig() = TorchTensorRealAlgebra {
|
||||
withChecks {
|
||||
withCuda { device ->
|
||||
testingBatchedSymEig(device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -6,10 +6,16 @@ import kotlin.test.*
|
||||
class TestUtils {
|
||||
|
||||
@Test
|
||||
fun testJTorch() {
|
||||
val tensor = JTorch.fullInt(54, intArrayOf(3), 0)
|
||||
println(JTorch.tensorToString(tensor))
|
||||
JTorch.disposeTensor(tensor)
|
||||
fun testSetNumThreads() {
|
||||
TorchTensorLongAlgebra {
|
||||
testingSetNumThreads()
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSeedSetting() = TorchTensorFloatAlgebra {
|
||||
withCuda { device ->
|
||||
testingSetSeed(device)
|
||||
}
|
||||
}
|
||||
}
|
@ -31,7 +31,6 @@ public sealed class TorchTensorAlgebraNative<
|
||||
set_seed(seed)
|
||||
}
|
||||
|
||||
|
||||
override var checks: Boolean = false
|
||||
|
||||
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
|
||||
@ -108,21 +107,16 @@ public sealed class TorchTensorAlgebraNative<
|
||||
}
|
||||
|
||||
override fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
|
||||
override fun TorchTensorType.absAssign(): Unit {
|
||||
abs_tensor_assign(tensorHandle)
|
||||
}
|
||||
override fun TorchTensorType.absAssign(): Unit = abs_tensor_assign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
|
||||
override fun TorchTensorType.sumAssign(): Unit {
|
||||
sum_tensor_assign(tensorHandle)
|
||||
}
|
||||
override fun TorchTensorType.sumAssign(): Unit = sum_tensor_assign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.randIntegral(low: Long, high: Long): TorchTensorType =
|
||||
wrap(randint_like(this.tensorHandle, low, high)!!)
|
||||
|
||||
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit {
|
||||
override fun TorchTensorType.randIntegralAssign(low: Long, high: Long): Unit =
|
||||
randint_like_assign(this.tensorHandle, low, high)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.copy(): TorchTensorType =
|
||||
wrap(copy_tensor(this.tensorHandle)!!)
|
||||
@ -130,10 +124,9 @@ public sealed class TorchTensorAlgebraNative<
|
||||
override fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
|
||||
wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
|
||||
|
||||
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit {
|
||||
override infix fun TorchTensorType.swap(other: TorchTensorType): Unit =
|
||||
swap_tensors(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
}
|
||||
|
||||
public sealed class TorchTensorPartialDivisionAlgebraNative<T, TVar : CPrimitiveVar,
|
||||
PrimitiveArrayType, TorchTensorType : TorchTensorOverFieldNative<T>>(scope: DeferScope) :
|
||||
@ -153,26 +146,21 @@ public sealed class TorchTensorPartialDivisionAlgebraNative<T, TVar : CPrimitive
|
||||
override fun TorchTensorType.randUniform(): TorchTensorType =
|
||||
wrap(rand_like(this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorType.randUniformAssign(): Unit {
|
||||
override fun TorchTensorType.randUniformAssign(): Unit =
|
||||
rand_like_assign(this.tensorHandle)
|
||||
}
|
||||
|
||||
|
||||
override fun TorchTensorType.randNormal(): TorchTensorType =
|
||||
wrap(randn_like(this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorType.randNormalAssign(): Unit {
|
||||
override fun TorchTensorType.randNormalAssign(): Unit =
|
||||
randn_like_assign(this.tensorHandle)
|
||||
}
|
||||
|
||||
|
||||
override fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
|
||||
override fun TorchTensorType.expAssign(): Unit {
|
||||
exp_tensor_assign(tensorHandle)
|
||||
}
|
||||
|
||||
override fun TorchTensorType.expAssign(): Unit = exp_tensor_assign(tensorHandle)
|
||||
override fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
|
||||
override fun TorchTensorType.logAssign(): Unit {
|
||||
log_tensor_assign(tensorHandle)
|
||||
}
|
||||
override fun TorchTensorType.logAssign(): Unit = log_tensor_assign(tensorHandle)
|
||||
|
||||
override fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
|
||||
val U = empty_tensor()!!
|
||||
@ -200,7 +188,7 @@ public sealed class TorchTensorPartialDivisionAlgebraNative<T, TVar : CPrimitive
|
||||
}
|
||||
|
||||
override fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
|
||||
wrap(detach_from_graph(this.tensorHandle)!!)
|
||||
|
||||
}
|
||||
|
||||
@ -305,9 +293,8 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
|
||||
wrap(plus_float(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.plusAssign(value: Float): Unit {
|
||||
override fun TorchTensorFloat.plusAssign(value: Float): Unit =
|
||||
plus_float_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(plus_float(-this, other.tensorHandle)!!)
|
||||
@ -315,9 +302,8 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
|
||||
wrap(plus_float(-value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.minusAssign(value: Float): Unit {
|
||||
override fun TorchTensorFloat.minusAssign(value: Float): Unit =
|
||||
plus_float_assign(-value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
|
||||
wrap(times_float(this, other.tensorHandle)!!)
|
||||
@ -325,9 +311,8 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||
override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
|
||||
wrap(times_float(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorFloat.timesAssign(value: Float): Unit {
|
||||
override fun TorchTensorFloat.timesAssign(value: Float): Unit =
|
||||
times_float_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
|
||||
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
@ -364,9 +349,8 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
|
||||
wrap(plus_long(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.plusAssign(value: Long): Unit {
|
||||
override fun TorchTensorLong.plusAssign(value: Long): Unit =
|
||||
plus_long_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(plus_long(-this, other.tensorHandle)!!)
|
||||
@ -374,9 +358,8 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
|
||||
wrap(plus_long(-value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.minusAssign(value: Long): Unit {
|
||||
override fun TorchTensorLong.minusAssign(value: Long): Unit =
|
||||
plus_long_assign(-value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
|
||||
wrap(times_long(this, other.tensorHandle)!!)
|
||||
@ -384,9 +367,8 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||
override fun TorchTensorLong.times(value: Long): TorchTensorLong =
|
||||
wrap(times_long(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorLong.timesAssign(value: Long): Unit {
|
||||
override fun TorchTensorLong.timesAssign(value: Long): Unit =
|
||||
times_long_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
|
||||
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
@ -422,9 +404,8 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||
override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
|
||||
wrap(plus_int(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.plusAssign(value: Int): Unit {
|
||||
override fun TorchTensorInt.plusAssign(value: Int): Unit =
|
||||
plus_int_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(plus_int(-this, other.tensorHandle)!!)
|
||||
@ -432,9 +413,8 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||
override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
|
||||
wrap(plus_int(-value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.minusAssign(value: Int): Unit {
|
||||
override fun TorchTensorInt.minusAssign(value: Int): Unit =
|
||||
plus_int_assign(-value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
|
||||
wrap(times_int(this, other.tensorHandle)!!)
|
||||
@ -442,9 +422,8 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||
override fun TorchTensorInt.times(value: Int): TorchTensorInt =
|
||||
wrap(times_int(value, this.tensorHandle)!!)
|
||||
|
||||
override fun TorchTensorInt.timesAssign(value: Int): Unit {
|
||||
override fun TorchTensorInt.timesAssign(value: Int): Unit =
|
||||
times_int_assign(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
|
||||
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||
|
@ -47,7 +47,6 @@ public sealed class TorchTensorNative<T> constructor(
|
||||
scope = scope,
|
||||
tensorHandle = copy_to_int(this.tensorHandle)!!
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
public sealed class TorchTensorOverFieldNative<T> constructor(
|
||||
|
@ -2,6 +2,7 @@ package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
internal class BenchmarkMatMul {
|
||||
|
||||
@Test
|
||||
|
@ -1,6 +1,6 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.*
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
internal class TestAutograd {
|
||||
|
@ -1,6 +1,6 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.*
|
||||
import kotlin.test.Test
|
||||
|
||||
|
||||
internal class TestTorchTensorAlgebra {
|
||||
|
Loading…
Reference in New Issue
Block a user