Adding systematic checks

This commit is contained in:
rgrit91 2021-01-16 15:52:05 +00:00 committed by Roland Grinis
parent ca2082405a
commit 4f4fcba559
21 changed files with 413 additions and 192 deletions

View File

@ -67,6 +67,8 @@ extern "C"
TorchTensorHandle rand_float(int *shape, int shape_size, int device); TorchTensorHandle rand_float(int *shape, int shape_size, int device);
TorchTensorHandle randn_float(int *shape, int shape_size, int device); TorchTensorHandle randn_float(int *shape, int shape_size, int device);
TorchTensorHandle randint_double(long low, long high, int *shape, int shape_size, int device);
TorchTensorHandle randint_float(long low, long high, int *shape, int shape_size, int device);
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device); TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device);
TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device); TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device);

View File

@ -103,7 +103,7 @@ namespace ctorch
} }
template <typename Dtype> template <typename Dtype>
inline torch::Tensor randint(Dtype low, Dtype high, std::vector<int64_t> shape, torch::Device device) inline torch::Tensor randint(long low, long high, std::vector<int64_t> shape, torch::Device device)
{ {
return torch::randint(low, high, shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device)); return torch::randint(low, high, shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
} }

View File

@ -197,6 +197,14 @@ TorchTensorHandle randn_float(int *shape, int shape_size, int device)
return new torch::Tensor(ctorch::randn<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::randn<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
} }
TorchTensorHandle randint_double(long low, long high, int *shape, int shape_size, int device)
{
return new torch::Tensor(ctorch::randint<double>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
}
TorchTensorHandle randint_float(long low, long high, int *shape, int shape_size, int device)
{
return new torch::Tensor(ctorch::randint<float>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
}
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device) TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device)
{ {
return new torch::Tensor(ctorch::randint<long>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::randint<long>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));

View File

@ -6,15 +6,15 @@ class BenchmarkMatMultGPU {
@Test @Test
fun benchmarkMatMultFloat20() = fun benchmarkMatMultFloat20() =
benchmarkingMatMultFloat(20, 10, 100000, benchmarkingMatMultFloat(20, 10, 100000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkMatMultFloat200() = fun benchmarkMatMultFloat200() =
benchmarkingMatMultFloat(200, 10, 10000, benchmarkingMatMultFloat(200, 10, 10000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkMatMultFloat2000() = fun benchmarkMatMultFloat2000() =
benchmarkingMatMultFloat(2000, 10, 1000, benchmarkingMatMultFloat(2000, 10, 1000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
} }

View File

@ -6,59 +6,59 @@ class BenchmarkRandomGeneratorsGPU {
@Test @Test
fun benchmarkRandNormal1() = fun benchmarkRandNormal1() =
benchmarkingRandNormal(10, 10, 100000, benchmarkingRandNormal(10, 10, 100000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkRandUniform1() = fun benchmarkRandUniform1() =
benchmarkingRandUniform(10, 10, 100000, benchmarkingRandUniform(10, 10, 100000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkRandIntegral1() = fun benchmarkRandIntegral1() =
benchmarkingRandIntegral(10, 10, 100000, benchmarkingRandIntegral(10, 10, 100000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkRandNormal3() = fun benchmarkRandNormal3() =
benchmarkingRandNormal(1000, 10, 100000, benchmarkingRandNormal(1000, 10, 100000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkRandUniform3() = fun benchmarkRandUniform3() =
benchmarkingRandUniform(1000, 10, 100000, benchmarkingRandUniform(1000, 10, 100000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkRandIntegral3() = fun benchmarkRandIntegral3() =
benchmarkingRandIntegral(1000, 10, 100000, benchmarkingRandIntegral(1000, 10, 100000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkRandNormal5() = fun benchmarkRandNormal5() =
benchmarkingRandNormal(100000, 10, 100000, benchmarkingRandNormal(100000, 10, 100000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkRandUniform5() = fun benchmarkRandUniform5() =
benchmarkingRandUniform(100000, 10, 100000, benchmarkingRandUniform(100000, 10, 100000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkRandIntegral5() = fun benchmarkRandIntegral5() =
benchmarkingRandIntegral(100000, 10, 100000, benchmarkingRandIntegral(100000, 10, 100000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkRandNormal7() = fun benchmarkRandNormal7() =
benchmarkingRandNormal(10000000, 10, 10000, benchmarkingRandNormal(10000000, 10, 10000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkRandUniform7() = fun benchmarkRandUniform7() =
benchmarkingRandUniform(10000000, 10, 10000, benchmarkingRandUniform(10000000, 10, 10000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
@Test @Test
fun benchmarkRandIntegral7() = fun benchmarkRandIntegral7() =
benchmarkingRandIntegral(10000000, 10, 10000, benchmarkingRandIntegral(10000000, 10, 10000,
device = TorchDevice.TorchCUDA(0)) device = Device.CUDA(0))
} }

View File

@ -5,9 +5,9 @@ import kotlin.test.*
internal class TestAutogradGPU { internal class TestAutogradGPU {
@Test @Test
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0)) fun testAutoGrad() = testingAutoGrad(dim = 3, device = Device.CUDA(0))
@Test @Test
fun testBatchedAutoGrad() = testingBatchedAutoGrad( fun testBatchedAutoGrad() = testingBatchedAutoGrad(
bath = intArrayOf(2), dim=3, device = TorchDevice.TorchCUDA(0)) bath = intArrayOf(2), dim=3, device = Device.CUDA(0))
} }

View File

@ -7,26 +7,26 @@ class TestTorchTensorAlgebraGPU {
@Test @Test
fun testScalarProduct() = fun testScalarProduct() =
testingScalarProduct(device = TorchDevice.TorchCUDA(0)) testingScalarProduct(device = Device.CUDA(0))
@Test @Test
fun testMatrixMultiplication() = fun testMatrixMultiplication() =
testingMatrixMultiplication(device = TorchDevice.TorchCUDA(0)) testingMatrixMultiplication(device = Device.CUDA(0))
@Test @Test
fun testLinearStructure() = fun testLinearStructure() =
testingLinearStructure(device = TorchDevice.TorchCUDA(0)) testingLinearStructure(device = Device.CUDA(0))
@Test @Test
fun testTensorTransformations() = fun testTensorTransformations() =
testingTensorTransformations(device = TorchDevice.TorchCUDA(0)) testingTensorTransformations(device = Device.CUDA(0))
@Test @Test
fun testBatchedSVD() = fun testBatchedSVD() =
testingBatchedSVD(device = TorchDevice.TorchCUDA(0)) testingBatchedSVD(device = Device.CUDA(0))
@Test @Test
fun testBatchedSymEig() = fun testBatchedSymEig() =
testingBatchedSymEig(device = TorchDevice.TorchCUDA(0)) testingBatchedSymEig(device = Device.CUDA(0))
} }

View File

@ -5,17 +5,17 @@ import kotlin.test.*
class TestTorchTensorGPU { class TestTorchTensorGPU {
@Test @Test
fun testCopyFromArray() = testingCopyFromArray(TorchDevice.TorchCUDA(0)) fun testCopyFromArray() = testingCopyFromArray(Device.CUDA(0))
@Test @Test
fun testCopyToDevice() = TorchTensorRealAlgebra { fun testCopyToDevice() = TorchTensorRealAlgebra {
setSeed(SEED) setSeed(SEED)
val normalCpu = randNormal(intArrayOf(2, 3)) val normalCpu = randNormal(intArrayOf(2, 3))
val normalGpu = normalCpu.copyToDevice(TorchDevice.TorchCUDA(0)) val normalGpu = normalCpu.copyToDevice(Device.CUDA(0))
assertTrue(normalCpu.copyToArray() contentEquals normalGpu.copyToArray()) assertTrue(normalCpu.copyToArray() contentEquals normalGpu.copyToArray())
val uniformGpu = randUniform(intArrayOf(3,2),TorchDevice.TorchCUDA(0)) val uniformGpu = randUniform(intArrayOf(3,2),Device.CUDA(0))
val uniformCpu = uniformGpu.copyToDevice(TorchDevice.TorchCPU) val uniformCpu = uniformGpu.copyToDevice(Device.CPU)
assertTrue(uniformGpu.copyToArray() contentEquals uniformCpu.copyToArray()) assertTrue(uniformGpu.copyToArray() contentEquals uniformCpu.copyToArray())
} }

View File

@ -11,6 +11,6 @@ internal class TestUtilsGPU {
} }
@Test @Test
fun testSetSeed() = testingSetSeed(TorchDevice.TorchCUDA(0)) fun testSetSeed() = testingSetSeed(Device.CUDA(0))
} }

View File

@ -0,0 +1,22 @@
package kscience.kmath.torch
public sealed class Device {
public object CPU: Device() {
override fun toString(): String {
return "CPU"
}
}
public data class CUDA(val index: Int): Device()
public fun toInt(): Int {
when(this) {
is CPU -> return 0
is CUDA -> return this.index + 1
}
}
public companion object {
public fun fromInt(deviceInt: Int): Device {
return if (deviceInt == 0) CPU else CUDA(deviceInt-1)
}
}
}

View File

@ -0,0 +1,50 @@
package kscience.kmath.torch
import kscience.kmath.operations.Field
import kscience.kmath.operations.Ring
public interface TensorAlgebra<T, TorchTensorType : TensorStructure<T>> : Ring<TorchTensorType> {
public operator fun T.plus(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.plus(value: T): TorchTensorType
public operator fun TorchTensorType.plusAssign(value: T): Unit
public operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit
public operator fun T.minus(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.minus(value: T): TorchTensorType
public operator fun TorchTensorType.minusAssign(value: T): Unit
public operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit
public operator fun T.times(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.times(value: T): TorchTensorType
public operator fun TorchTensorType.timesAssign(value: T): Unit
public operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit
public infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType
public fun diagonalEmbedding(
diagonalEntries: TorchTensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
): TorchTensorType
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType
public fun TorchTensorType.view(shape: IntArray): TorchTensorType
public fun TorchTensorType.abs(): TorchTensorType
public fun TorchTensorType.sum(): TorchTensorType
}
public interface TensorFieldAlgebra<T, TorchTensorType : TensorStructure<T>> :
TensorAlgebra<T, TorchTensorType>, Field<TorchTensorType> {
public operator fun TorchTensorType.divAssign(b: TorchTensorType)
public fun TorchTensorType.exp(): TorchTensorType
public fun TorchTensorType.log(): TorchTensorType
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType>
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType>
}

View File

@ -0,0 +1,13 @@
package kscience.kmath.torch
import kscience.kmath.structures.MutableNDStructure
public abstract class TensorStructure<T>: MutableNDStructure<T> {
// A tensor can have empty shape, in which case it represents just a value
public abstract fun value(): T
// Tensors might hold shared resources
override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0
}

View File

@ -1,22 +0,0 @@
package kscience.kmath.torch
public sealed class TorchDevice {
public object TorchCPU: TorchDevice() {
override fun toString(): String {
return "TorchCPU"
}
}
public data class TorchCUDA(val index: Int): TorchDevice()
public fun toInt(): Int {
when(this) {
is TorchCPU -> return 0
is TorchCUDA -> return this.index + 1
}
}
public companion object {
public fun fromInt(deviceInt: Int): TorchDevice {
return if (deviceInt == 0) TorchCPU else TorchCUDA(deviceInt-1)
}
}
}

View File

@ -1,6 +1,5 @@
package kscience.kmath.torch package kscience.kmath.torch
import kscience.kmath.structures.MutableNDStructure
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kscience.kmath.ctorch.* import kscience.kmath.ctorch.*
@ -9,7 +8,7 @@ import kscience.kmath.ctorch.*
public sealed class TorchTensor<T> constructor( public sealed class TorchTensor<T> constructor(
internal val scope: DeferScope, internal val scope: DeferScope,
internal val tensorHandle: COpaquePointer internal val tensorHandle: COpaquePointer
) : MutableNDStructure<T> { ) : TensorStructure<T>() {
init { init {
scope.defer(::close) scope.defer(::close)
} }
@ -23,10 +22,8 @@ public sealed class TorchTensor<T> constructor(
public val strides: IntArray public val strides: IntArray
get() = (1..dimension).map{get_stride_at(tensorHandle, it-1)}.toIntArray() get() = (1..dimension).map{get_stride_at(tensorHandle, it-1)}.toIntArray()
public val size: Int get() = get_numel(tensorHandle) public val size: Int get() = get_numel(tensorHandle)
public val device: TorchDevice get() = TorchDevice.fromInt(get_device(tensorHandle)) public val device: Device get() = Device.fromInt(get_device(tensorHandle))
override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0
override fun toString(): String { override fun toString(): String {
val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(tensorHandle)!! val nativeStringRepresentation: CPointer<ByteVar> = tensor_to_string(tensorHandle)!!
val stringRepresentation = nativeStringRepresentation.toKString() val stringRepresentation = nativeStringRepresentation.toKString()
@ -42,12 +39,13 @@ public sealed class TorchTensor<T> constructor(
return indices.map { it to get(it) } return indices.map { it to get(it) }
} }
internal inline fun isValue() = check(dimension == 0) { public inline fun isValue(): Boolean = dimension == 0
public inline fun isNotValue(): Boolean = !isValue()
internal inline fun checkIsValue() = check(isValue()) {
"This tensor has shape ${shape.toList()}" "This tensor has shape ${shape.toList()}"
} }
override fun value(): T {
public fun value(): T { checkIsValue()
isValue()
return item() return item()
} }

View File

@ -9,13 +9,15 @@ public sealed class TorchTensorAlgebra<
PrimitiveArrayType, PrimitiveArrayType,
TorchTensorType : TorchTensor<T>> constructor( TorchTensorType : TorchTensor<T>> constructor(
internal val scope: DeferScope internal val scope: DeferScope
) { ) :
TensorAlgebra<T, TorchTensorType> {
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
public abstract fun copyFromArray( public abstract fun copyFromArray(
array: PrimitiveArrayType, array: PrimitiveArrayType,
shape: IntArray, shape: IntArray,
device: TorchDevice = TorchDevice.TorchCPU device: Device = Device.CPU
): TorchTensorType ): TorchTensorType
public abstract fun TorchTensorType.copyToArray(): PrimitiveArrayType public abstract fun TorchTensorType.copyToArray(): PrimitiveArrayType
@ -23,96 +25,219 @@ public sealed class TorchTensorAlgebra<
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
public abstract fun TorchTensorType.getData(): CPointer<TVar> public abstract fun TorchTensorType.getData(): CPointer<TVar>
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensorType public abstract fun full(value: T, shape: IntArray, device: Device): TorchTensorType
public abstract operator fun T.plus(other: TorchTensorType): TorchTensorType public abstract fun randIntegral(
public abstract operator fun TorchTensorType.plus(value: T): TorchTensorType low: T, high: T, shape: IntArray,
public abstract operator fun TorchTensorType.plusAssign(value: T): Unit device: Device = Device.CPU
public abstract operator fun T.minus(other: TorchTensorType): TorchTensorType ): TorchTensorType
public abstract operator fun TorchTensorType.minus(value: T): TorchTensorType
public abstract operator fun TorchTensorType.minusAssign(value: T): Unit
public abstract operator fun T.times(other: TorchTensorType): TorchTensorType
public abstract operator fun TorchTensorType.times(value: T): TorchTensorType
public abstract operator fun TorchTensorType.timesAssign(value: T): Unit
public operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType = public abstract fun TorchTensorType.randIntegral(low: T, high: T): TorchTensorType
wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!) public abstract fun TorchTensorType.randIntegralAssign(low: T, high: T): Unit
public operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit { override val zero: TorchTensorType
times_tensor_assign(this.tensorHandle, other.tensorHandle) get() = number(0)
override val one: TorchTensorType
get() = number(1)
protected inline fun checkDeviceCompatible(a: TorchTensorType, b: TorchTensorType) =
check(a.device == b.device) {
"Tensors must be on the same device"
} }
public infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType = protected inline fun checkShapeCompatible(a: TorchTensorType, b: TorchTensorType) =
wrap(matmul(this.tensorHandle, other.tensorHandle)!!) check(a.shape contentEquals b.shape) {
"Tensors must be of identical shape"
public infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
matmul_assign(this.tensorHandle, other.tensorHandle)
} }
public infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit { protected inline fun checkLinearOperation(a: TorchTensorType, b: TorchTensorType) {
matmul_right_assign(this.tensorHandle, other.tensorHandle) if (a.isNotValue() and b.isNotValue()) {
checkDeviceCompatible(a, b)
checkShapeCompatible(a, b)
}
} }
public operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType = override operator fun TorchTensorType.times(b: TorchTensorType): TorchTensorType =
wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!) this.times(b, safe = true)
public operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit { public fun TorchTensorType.times(b: TorchTensorType, safe: Boolean): TorchTensorType {
plus_tensor_assign(this.tensorHandle, other.tensorHandle) if (safe) checkLinearOperation(this, b)
return wrap(times_tensor(this.tensorHandle, b.tensorHandle)!!)
} }
public operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType = override operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit =
wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!) this.timesAssign(b, safe = true)
public operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit { public fun TorchTensorType.timesAssign(b: TorchTensorType, safe: Boolean): Unit {
minus_tensor_assign(this.tensorHandle, other.tensorHandle) if (safe) checkLinearOperation(this, b)
times_tensor_assign(this.tensorHandle, b.tensorHandle)
} }
public operator fun TorchTensorType.unaryMinus(): TorchTensorType = override fun multiply(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a * b
override operator fun TorchTensorType.plus(b: TorchTensorType): TorchTensorType =
this.plus(b, safe = true)
public fun TorchTensorType.plus(b: TorchTensorType, safe: Boolean): TorchTensorType {
if (safe) checkLinearOperation(this, b)
return wrap(plus_tensor(this.tensorHandle, b.tensorHandle)!!)
}
override operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit =
this.plusAssign(b, false)
public fun TorchTensorType.plusAssign(b: TorchTensorType, safe: Boolean): Unit {
if (safe) checkLinearOperation(this, b)
plus_tensor_assign(this.tensorHandle, b.tensorHandle)
}
override fun add(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a + b
override operator fun TorchTensorType.minus(b: TorchTensorType): TorchTensorType =
this.minus(b, safe = true)
public fun TorchTensorType.minus(b: TorchTensorType, safe: Boolean): TorchTensorType {
if (safe) checkLinearOperation(this, b)
return wrap(minus_tensor(this.tensorHandle, b.tensorHandle)!!)
}
override operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit =
this.minusAssign(b, safe = true)
public fun TorchTensorType.minusAssign(b: TorchTensorType, safe: Boolean): Unit {
if (safe) checkLinearOperation(this, b)
minus_tensor_assign(this.tensorHandle, b.tensorHandle)
}
override operator fun TorchTensorType.unaryMinus(): TorchTensorType =
wrap(unary_minus(this.tensorHandle)!!) wrap(unary_minus(this.tensorHandle)!!)
public fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!) private inline fun checkDotOperation(a: TorchTensorType, b: TorchTensorType): Unit {
checkDeviceCompatible(a, b)
val sa = a.shape
val sb = b.shape
val na = sa.size
val nb = sb.size
var status: Boolean
if (nb == 1) {
status = sa.last() == sb[0]
} else {
status = sa.last() == sb[nb - 2]
if ((na > 2) and (nb > 2)) {
status = status and
(sa.take(nb - 2).toIntArray() contentEquals sb.take(nb - 2).toIntArray())
}
}
check(status) { "Incompatible shapes $sa and $sb for dot product" }
}
override infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType =
this.dot(b, safe = true)
public fun TorchTensorType.dot(b: TorchTensorType, safe: Boolean): TorchTensorType {
if (safe) checkDotOperation(this, b)
return wrap(matmul(this.tensorHandle, b.tensorHandle)!!)
}
public infix fun TorchTensorType.dotAssign(b: TorchTensorType): Unit =
this.dotAssign(b, safe = true)
public fun TorchTensorType.dotAssign(b: TorchTensorType, safe: Boolean): Unit {
if (safe) checkDotOperation(this, b)
matmul_assign(this.tensorHandle, b.tensorHandle)
}
public infix fun TorchTensorType.dotRightAssign(b: TorchTensorType): Unit =
this.dotRightAssign(b, safe = true)
public fun TorchTensorType.dotRightAssign(b: TorchTensorType, safe: Boolean): Unit {
if (safe) checkDotOperation(this, b)
matmul_right_assign(this.tensorHandle, b.tensorHandle)
}
override fun diagonalEmbedding(
diagonalEntries: TorchTensorType, offset: Int, dim1: Int, dim2: Int
): TorchTensorType =
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
private inline fun checkTranspose(dim: Int, i: Int, j: Int): Unit =
check((i < dim) and (j < dim)) {
"Cannot transpose $i to $j for a tensor of dim $dim"
}
override fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
this.transpose(i, j, safe = true)
public fun TorchTensorType.transpose(i: Int, j: Int, safe: Boolean): TorchTensorType {
if (safe) checkTranspose(this.dimension, i, j)
return wrap(transpose_tensor(tensorHandle, i, j)!!)
}
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit =
this.transposeAssign(i, j, safe = true)
public fun TorchTensorType.transposeAssign(i: Int, j: Int, safe: Boolean): Unit {
if (safe) checkTranspose(this.dimension, i, j)
transpose_tensor_assign(tensorHandle, i, j)
}
private inline fun checkView(a: TorchTensorType, shape: IntArray): Unit =
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
override fun TorchTensorType.view(shape: IntArray): TorchTensorType =
this.view(shape, safe = true)
public fun TorchTensorType.view(shape: IntArray, safe: Boolean): TorchTensorType {
if (safe) checkView(this, shape)
return wrap(view_tensor(this.tensorHandle, shape.toCValues(), shape.size)!!)
}
override fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
public fun TorchTensorType.absAssign(): Unit { public fun TorchTensorType.absAssign(): Unit {
abs_tensor_assign(tensorHandle) abs_tensor_assign(tensorHandle)
} }
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType = override fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
wrap(transpose_tensor(tensorHandle, i, j)!!)
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
transpose_tensor_assign(tensorHandle, i, j)
}
public fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
public fun TorchTensorType.sumAssign(): Unit { public fun TorchTensorType.sumAssign(): Unit {
sum_tensor_assign(tensorHandle) sum_tensor_assign(tensorHandle)
} }
public fun diagEmbed(
diagonalEntries: TorchTensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
): TorchTensorType =
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
public fun TorchTensorType.copy(): TorchTensorType = public fun TorchTensorType.copy(): TorchTensorType =
wrap(copy_tensor(this.tensorHandle)!!) wrap(copy_tensor(this.tensorHandle)!!)
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType = public fun TorchTensorType.copyToDevice(device: Device): TorchTensorType =
wrap(copy_to_device(this.tensorHandle, device.toInt())!!) wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
public infix fun TorchTensorType.swap(otherTensor: TorchTensorType): Unit { public infix fun TorchTensorType.swap(otherTensor: TorchTensorType): Unit {
swap_tensors(this.tensorHandle, otherTensor.tensorHandle) swap_tensors(this.tensorHandle, otherTensor.tensorHandle)
} }
public fun TorchTensorType.view(shape: IntArray): TorchTensorType =
wrap(view_tensor(this.tensorHandle, shape.toCValues(), shape.size)!!)
} }
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar, public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) : PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) { TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope),
TensorFieldAlgebra<T, TorchTensorType> {
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType override operator fun TorchTensorType.div(b: TorchTensorType): TorchTensorType =
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType this.div(b, safe = true)
public fun TorchTensorType.div(b: TorchTensorType, safe: Boolean): TorchTensorType {
if (safe) checkLinearOperation(this, b)
return wrap(div_tensor(this.tensorHandle, b.tensorHandle)!!)
}
override operator fun TorchTensorType.divAssign(b: TorchTensorType): Unit =
this.divAssign(b, safe = true)
public fun TorchTensorType.divAssign(b: TorchTensorType, safe: Boolean): Unit {
if (safe) checkLinearOperation(this, b)
div_tensor_assign(this.tensorHandle, b.tensorHandle)
}
override fun divide(a: TorchTensorType, b: TorchTensorType): TorchTensorType = a / b
public abstract fun randUniform(shape: IntArray, device: Device = Device.CPU): TorchTensorType
public abstract fun randNormal(shape: IntArray, device: Device = Device.CPU): TorchTensorType
public fun TorchTensorType.randUniform(): TorchTensorType = public fun TorchTensorType.randUniform(): TorchTensorType =
wrap(rand_like(this.tensorHandle)!!) wrap(rand_like(this.tensorHandle)!!)
@ -128,24 +253,17 @@ public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
randn_like_assign(this.tensorHandle) randn_like_assign(this.tensorHandle)
} }
public operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType = override fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
div_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
public fun TorchTensorType.expAssign(): Unit { public fun TorchTensorType.expAssign(): Unit {
exp_tensor_assign(tensorHandle) exp_tensor_assign(tensorHandle)
} }
public fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!) override fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
public fun TorchTensorType.logAssign(): Unit { public fun TorchTensorType.logAssign(): Unit {
log_tensor_assign(tensorHandle) log_tensor_assign(tensorHandle)
} }
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> { override fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
val U = empty_tensor()!! val U = empty_tensor()!!
val V = empty_tensor()!! val V = empty_tensor()!!
val S = empty_tensor()!! val S = empty_tensor()!!
@ -153,7 +271,7 @@ public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
return Triple(wrap(U), wrap(S), wrap(V)) return Triple(wrap(U), wrap(S), wrap(V))
} }
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType> { override fun TorchTensorType.symEig(eigenvectors: Boolean): Pair<TorchTensorType, TorchTensorType> {
val V = empty_tensor()!! val V = empty_tensor()!!
val S = empty_tensor()!! val S = empty_tensor()!!
symeig_tensor(this.tensorHandle, S, V, eigenvectors) symeig_tensor(this.tensorHandle, S, V, eigenvectors)
@ -161,13 +279,15 @@ public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
} }
public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType { public fun TorchTensorType.grad(variable: TorchTensorType, retainGraph: Boolean = false): TorchTensorType {
this.isValue() this.checkIsValue()
return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!) return wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle, retainGraph)!!)
} }
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType = public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
this.grad(variable, false) this.grad(variable, false)
public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType { public infix fun TorchTensorType.hess(variable: TorchTensorType): TorchTensorType {
this.isValue() this.checkIsValue()
return wrap(autohess_tensor(this.tensorHandle, variable.tensorHandle)!!) return wrap(autohess_tensor(this.tensorHandle, variable.tensorHandle)!!)
} }
@ -175,42 +295,34 @@ public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!) wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
} }
public sealed class TorchTensorRingAlgebra<T, TVar : CPrimitiveVar,
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) {
public abstract fun randIntegral(
low: T, high: T, shape: IntArray,
device: TorchDevice = TorchDevice.TorchCPU
): TorchTensorType
public abstract fun TorchTensorType.randIntegral(low: T, high: T): TorchTensorType
public abstract fun TorchTensorType.randIntegralAssign(low: T, high: T): Unit
}
public class TorchTensorRealAlgebra(scope: DeferScope) : public class TorchTensorRealAlgebra(scope: DeferScope) :
TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) { TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal = override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
TorchTensorReal(scope = scope, tensorHandle = tensorHandle) TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
override fun number(value: Number): TorchTensorReal =
full(value.toDouble(), intArrayOf(1), Device.CPU).sum()
override fun TorchTensorReal.copyToArray(): DoubleArray = override fun TorchTensorReal.copyToArray(): DoubleArray =
this.elements().map { it.second }.toList().toDoubleArray() this.elements().map { it.second }.toList().toDoubleArray()
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: TorchDevice): TorchTensorReal = override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): TorchTensorReal =
wrap(from_blob_double(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!) wrap(from_blob_double(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal = override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
wrap(from_blob_double(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!) wrap(from_blob_double(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
override fun TorchTensorReal.getData(): CPointer<DoubleVar> { override fun TorchTensorReal.getData(): CPointer<DoubleVar> {
require(this.device is TorchDevice.TorchCPU) { require(this.device is Device.CPU) {
"This tensor is not on available on CPU" "This tensor is not on available on CPU"
} }
return get_data_double(this.tensorHandle)!! return get_data_double(this.tensorHandle)!!
} }
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal = override fun randNormal(shape: IntArray, device: Device): TorchTensorReal =
wrap(randn_double(shape.toCValues(), shape.size, device.toInt())!!) wrap(randn_double(shape.toCValues(), shape.size, device.toInt())!!)
override fun randUniform(shape: IntArray, device: TorchDevice): TorchTensorReal = override fun randUniform(shape: IntArray, device: Device): TorchTensorReal =
wrap(rand_double(shape.toCValues(), shape.size, device.toInt())!!) wrap(rand_double(shape.toCValues(), shape.size, device.toInt())!!)
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal = override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
@ -243,8 +355,20 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
times_double_assign(value, this.tensorHandle) times_double_assign(value, this.tensorHandle)
} }
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = override fun multiply(a: TorchTensorReal, k: Number): TorchTensorReal = a * k.toDouble()
override fun full(value: Double, shape: IntArray, device: Device): TorchTensorReal =
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!) wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
override fun randIntegral(low: Double, high: Double, shape: IntArray, device: Device): TorchTensorReal =
wrap(randint_double(low.toLong(), high.toLong(), shape.toCValues(), shape.size, device.toInt())!!)
override fun TorchTensorReal.randIntegral(low: Double, high: Double): TorchTensorReal =
wrap(randint_long_like(this.tensorHandle, low.toLong(), high.toLong())!!)
override fun TorchTensorReal.randIntegralAssign(low: Double, high: Double): Unit {
randint_long_like_assign(this.tensorHandle, low.toLong(), high.toLong())
}
} }
public class TorchTensorFloatAlgebra(scope: DeferScope) : public class TorchTensorFloatAlgebra(scope: DeferScope) :
@ -252,26 +376,29 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat = override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle) TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
override fun number(value: Number): TorchTensorFloat =
full(value.toFloat(), intArrayOf(1), Device.CPU).sum()
override fun TorchTensorFloat.copyToArray(): FloatArray = override fun TorchTensorFloat.copyToArray(): FloatArray =
this.elements().map { it.second }.toList().toFloatArray() this.elements().map { it.second }.toList().toFloatArray()
override fun copyFromArray(array: FloatArray, shape: IntArray, device: TorchDevice): TorchTensorFloat = override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): TorchTensorFloat =
wrap(from_blob_float(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!) wrap(from_blob_float(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
override fun fromBlob(arrayBlob: CPointer<FloatVar>, shape: IntArray): TorchTensorFloat = override fun fromBlob(arrayBlob: CPointer<FloatVar>, shape: IntArray): TorchTensorFloat =
wrap(from_blob_float(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!) wrap(from_blob_float(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
override fun TorchTensorFloat.getData(): CPointer<FloatVar> { override fun TorchTensorFloat.getData(): CPointer<FloatVar> {
require(this.device is TorchDevice.TorchCPU) { require(this.device is Device.CPU) {
"This tensor is not on available on CPU" "This tensor is not on available on CPU"
} }
return get_data_float(this.tensorHandle)!! return get_data_float(this.tensorHandle)!!
} }
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorFloat = override fun randNormal(shape: IntArray, device: Device): TorchTensorFloat =
wrap(randn_float(shape.toCValues(), shape.size, device.toInt())!!) wrap(randn_float(shape.toCValues(), shape.size, device.toInt())!!)
override fun randUniform(shape: IntArray, device: TorchDevice): TorchTensorFloat = override fun randUniform(shape: IntArray, device: Device): TorchTensorFloat =
wrap(rand_float(shape.toCValues(), shape.size, device.toInt())!!) wrap(rand_float(shape.toCValues(), shape.size, device.toInt())!!)
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat = override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
@ -304,36 +431,52 @@ public class TorchTensorFloatAlgebra(scope: DeferScope) :
times_float_assign(value, this.tensorHandle) times_float_assign(value, this.tensorHandle)
} }
override fun full(value: Float, shape: IntArray, device: TorchDevice): TorchTensorFloat = override fun multiply(a: TorchTensorFloat, k: Number): TorchTensorFloat = a * k.toFloat()
override fun full(value: Float, shape: IntArray, device: Device): TorchTensorFloat =
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!) wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
override fun randIntegral(low: Float, high: Float, shape: IntArray, device: Device): TorchTensorFloat =
wrap(randint_float(low.toLong(), high.toLong(), shape.toCValues(), shape.size, device.toInt())!!)
override fun TorchTensorFloat.randIntegral(low: Float, high: Float): TorchTensorFloat =
wrap(randint_long_like(this.tensorHandle, low.toLong(), high.toLong())!!)
override fun TorchTensorFloat.randIntegralAssign(low: Float, high: Float): Unit {
randint_long_like_assign(this.tensorHandle, low.toLong(), high.toLong())
}
} }
public class TorchTensorLongAlgebra(scope: DeferScope) : public class TorchTensorLongAlgebra(scope: DeferScope) :
TorchTensorRingAlgebra<Long, LongVar, LongArray, TorchTensorLong>(scope) { TorchTensorAlgebra<Long, LongVar, LongArray, TorchTensorLong>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong = override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong =
TorchTensorLong(scope = scope, tensorHandle = tensorHandle) TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
override fun number(value: Number): TorchTensorLong =
full(value.toLong(), intArrayOf(1), Device.CPU).sum()
override fun TorchTensorLong.copyToArray(): LongArray = override fun TorchTensorLong.copyToArray(): LongArray =
this.elements().map { it.second }.toList().toLongArray() this.elements().map { it.second }.toList().toLongArray()
override fun copyFromArray(array: LongArray, shape: IntArray, device: TorchDevice): TorchTensorLong = override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): TorchTensorLong =
wrap(from_blob_long(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!) wrap(from_blob_long(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
override fun fromBlob(arrayBlob: CPointer<LongVar>, shape: IntArray): TorchTensorLong = override fun fromBlob(arrayBlob: CPointer<LongVar>, shape: IntArray): TorchTensorLong =
wrap(from_blob_long(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!) wrap(from_blob_long(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
override fun TorchTensorLong.getData(): CPointer<LongVar> { override fun TorchTensorLong.getData(): CPointer<LongVar> {
check(this.device is TorchDevice.TorchCPU) { check(this.device is Device.CPU) {
"This tensor is not on available on CPU" "This tensor is not on available on CPU"
} }
return get_data_long(this.tensorHandle)!! return get_data_long(this.tensorHandle)!!
} }
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: TorchDevice): TorchTensorLong = override fun randIntegral(low: Long, high: Long, shape: IntArray, device: Device): TorchTensorLong =
wrap(randint_long(low, high, shape.toCValues(), shape.size, device.toInt())!!) wrap(randint_long(low, high, shape.toCValues(), shape.size, device.toInt())!!)
override fun TorchTensorLong.randIntegral(low: Long, high: Long): TorchTensorLong = override fun TorchTensorLong.randIntegral(low: Long, high: Long): TorchTensorLong =
wrap(randint_long_like(this.tensorHandle, low, high)!!) wrap(randint_long_like(this.tensorHandle, low, high)!!)
override fun TorchTensorLong.randIntegralAssign(low: Long, high: Long): Unit { override fun TorchTensorLong.randIntegralAssign(low: Long, high: Long): Unit {
randint_long_like_assign(this.tensorHandle, low, high) randint_long_like_assign(this.tensorHandle, low, high)
} }
@ -368,32 +511,37 @@ public class TorchTensorLongAlgebra(scope: DeferScope) :
times_long_assign(value, this.tensorHandle) times_long_assign(value, this.tensorHandle)
} }
override fun full(value: Long, shape: IntArray, device: TorchDevice): TorchTensorLong = override fun multiply(a: TorchTensorLong, k: Number): TorchTensorLong = a * k.toLong()
override fun full(value: Long, shape: IntArray, device: Device): TorchTensorLong =
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!) wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
} }
public class TorchTensorIntAlgebra(scope: DeferScope) : public class TorchTensorIntAlgebra(scope: DeferScope) :
TorchTensorRingAlgebra<Int, IntVar, IntArray, TorchTensorInt>(scope) { TorchTensorAlgebra<Int, IntVar, IntArray, TorchTensorInt>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt = override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt =
TorchTensorInt(scope = scope, tensorHandle = tensorHandle) TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
override fun number(value: Number): TorchTensorInt =
full(value.toInt(), intArrayOf(1), Device.CPU).sum()
override fun TorchTensorInt.copyToArray(): IntArray = override fun TorchTensorInt.copyToArray(): IntArray =
this.elements().map { it.second }.toList().toIntArray() this.elements().map { it.second }.toList().toIntArray()
override fun copyFromArray(array: IntArray, shape: IntArray, device: TorchDevice): TorchTensorInt = override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): TorchTensorInt =
wrap(from_blob_int(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!) wrap(from_blob_int(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
override fun fromBlob(arrayBlob: CPointer<IntVar>, shape: IntArray): TorchTensorInt = override fun fromBlob(arrayBlob: CPointer<IntVar>, shape: IntArray): TorchTensorInt =
wrap(from_blob_int(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!) wrap(from_blob_int(arrayBlob, shape.toCValues(), shape.size, Device.CPU.toInt(), false)!!)
override fun TorchTensorInt.getData(): CPointer<IntVar> { override fun TorchTensorInt.getData(): CPointer<IntVar> {
require(this.device is TorchDevice.TorchCPU) { require(this.device is Device.CPU) {
"This tensor is not on available on CPU" "This tensor is not on available on CPU"
} }
return get_data_int(this.tensorHandle)!! return get_data_int(this.tensorHandle)!!
} }
override fun randIntegral(low: Int, high: Int, shape: IntArray, device: TorchDevice): TorchTensorInt = override fun randIntegral(low: Int, high: Int, shape: IntArray, device: Device): TorchTensorInt =
wrap(randint_int(low, high, shape.toCValues(), shape.size, device.toInt())!!) wrap(randint_int(low, high, shape.toCValues(), shape.size, device.toInt())!!)
override fun TorchTensorInt.randIntegral(low: Int, high: Int): TorchTensorInt = override fun TorchTensorInt.randIntegral(low: Int, high: Int): TorchTensorInt =
@ -433,7 +581,9 @@ public class TorchTensorIntAlgebra(scope: DeferScope) :
times_int_assign(value, this.tensorHandle) times_int_assign(value, this.tensorHandle)
} }
override fun full(value: Int, shape: IntArray, device: TorchDevice): TorchTensorInt = override fun multiply(a: TorchTensorInt, k: Number): TorchTensorInt = a * k.toInt()
override fun full(value: Int, shape: IntArray, device: Device): TorchTensorInt =
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!) wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
} }

View File

@ -7,15 +7,15 @@ internal fun benchmarkingMatMultDouble(
scale: Int, scale: Int,
numWarmUp: Int, numWarmUp: Int,
numIter: Int, numIter: Int,
device: TorchDevice = TorchDevice.TorchCPU device: Device = Device.CPU
): Unit { ): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
println("Benchmarking $scale x $scale matrices over Double's on $device: ") println("Benchmarking $scale x $scale matrices over Double's on $device: ")
setSeed(SEED) setSeed(SEED)
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device) val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device) val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
repeat(numWarmUp) { lhs dotAssign rhs } repeat(numWarmUp) { lhs.dotAssign(rhs, false) }
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } } val measuredTime = measureTime { repeat(numIter) { lhs.dotAssign(rhs, false) } }
println(" ${measuredTime / numIter} p.o. with $numIter iterations") println(" ${measuredTime / numIter} p.o. with $numIter iterations")
} }
} }
@ -24,15 +24,15 @@ internal fun benchmarkingMatMultFloat(
scale: Int, scale: Int,
numWarmUp: Int, numWarmUp: Int,
numIter: Int, numIter: Int,
device: TorchDevice = TorchDevice.TorchCPU device: Device = Device.CPU
): Unit { ): Unit {
TorchTensorFloatAlgebra { TorchTensorFloatAlgebra {
println("Benchmarking $scale x $scale matrices over Float's on $device: ") println("Benchmarking $scale x $scale matrices over Float's on $device: ")
setSeed(SEED) setSeed(SEED)
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device) val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device) val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
repeat(numWarmUp) { lhs dotAssign rhs } repeat(numWarmUp) { lhs.dotAssign(rhs, false) }
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } } val measuredTime = measureTime { repeat(numIter) { lhs.dotAssign(rhs, false) } }
println(" ${measuredTime / numIter} p.o. with $numIter iterations") println(" ${measuredTime / numIter} p.o. with $numIter iterations")
} }
} }

View File

@ -7,7 +7,7 @@ internal fun benchmarkingRandNormal(
samples: Int, samples: Int,
numWarmUp: Int, numWarmUp: Int,
numIter: Int, numIter: Int,
device: TorchDevice = TorchDevice.TorchCPU): Unit device: Device = Device.CPU): Unit
{ {
TorchTensorFloatAlgebra{ TorchTensorFloatAlgebra{
println("Benchmarking generation of $samples Normal samples on $device: ") println("Benchmarking generation of $samples Normal samples on $device: ")
@ -23,7 +23,7 @@ internal fun benchmarkingRandUniform(
samples: Int, samples: Int,
numWarmUp: Int, numWarmUp: Int,
numIter: Int, numIter: Int,
device: TorchDevice = TorchDevice.TorchCPU): Unit device: Device = Device.CPU): Unit
{ {
TorchTensorFloatAlgebra{ TorchTensorFloatAlgebra{
println("Benchmarking generation of $samples Uniform samples on $device: ") println("Benchmarking generation of $samples Uniform samples on $device: ")
@ -40,7 +40,7 @@ internal fun benchmarkingRandIntegral(
samples: Int, samples: Int,
numWarmUp: Int, numWarmUp: Int,
numIter: Int, numIter: Int,
device: TorchDevice = TorchDevice.TorchCPU): Unit device: Device = Device.CPU): Unit
{ {
TorchTensorIntAlgebra { TorchTensorIntAlgebra {
println("Benchmarking generation of $samples integer [0,100] samples on $device: ") println("Benchmarking generation of $samples integer [0,100] samples on $device: ")

View File

@ -2,7 +2,7 @@ package kscience.kmath.torch
import kotlin.test.* import kotlin.test.*
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingAutoGrad(dim: Int, device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
setSeed(SEED) setSeed(SEED)
val tensorX = randNormal(shape = intArrayOf(dim), device = device) val tensorX = randNormal(shape = intArrayOf(dim), device = device)
@ -26,7 +26,7 @@ internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCP
internal fun testingBatchedAutoGrad(bath: IntArray, internal fun testingBatchedAutoGrad(bath: IntArray,
dim: Int, dim: Int,
device: TorchDevice = TorchDevice.TorchCPU): Unit { device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
setSeed(SEED) setSeed(SEED)

View File

@ -3,7 +3,7 @@ package kscience.kmath.torch
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kotlin.test.* import kotlin.test.*
internal fun testingCopyFromArray(device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingCopyFromArray(device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
val array = (1..24).map { 10.0 * it * it }.toDoubleArray() val array = (1..24).map { 10.0 * it * it }.toDoubleArray()
val shape = intArrayOf(2, 3, 4) val shape = intArrayOf(2, 3, 4)

View File

@ -6,7 +6,7 @@ import kscience.kmath.structures.Matrix
import kotlin.math.* import kotlin.math.*
import kotlin.test.* import kotlin.test.*
internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingScalarProduct(device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
val lhs = randUniform(shape = intArrayOf(3), device = device) val lhs = randUniform(shape = intArrayOf(3), device = device)
val rhs = randUniform(shape = intArrayOf(3), device = device) val rhs = randUniform(shape = intArrayOf(3), device = device)
@ -19,7 +19,7 @@ internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): U
} }
} }
internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingMatrixMultiplication(device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
setSeed(SEED) setSeed(SEED)
@ -49,7 +49,7 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch
} }
} }
internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingLinearStructure(device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
val shape = intArrayOf(3) val shape = intArrayOf(3)
val tensorA = full(value = -4.5, shape = shape, device = device) val tensorA = full(value = -4.5, shape = shape, device = device)
@ -84,7 +84,7 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
} }
} }
internal fun testingTensorTransformations(device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingTensorTransformations(device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
setSeed(SEED) setSeed(SEED)
val tensor = randNormal(shape = intArrayOf(3, 3), device = device) val tensor = randNormal(shape = intArrayOf(3, 3), device = device)
@ -102,21 +102,21 @@ internal fun testingTensorTransformations(device: TorchDevice = TorchDevice.Torc
} }
} }
internal fun testingBatchedSVD(device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingBatchedSVD(device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
val tensor = randNormal(shape = intArrayOf(7, 5, 3), device = device) val tensor = randNormal(shape = intArrayOf(7, 5, 3), device = device)
val (tensorU, tensorS, tensorV) = tensor.svd() val (tensorU, tensorS, tensorV) = tensor.svd()
val error = tensor - (tensorU dot (diagEmbed(tensorS) dot tensorV.transpose(-2,-1))) val error = tensor - (tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2,-1)))
assertTrue(error.abs().sum().value() < TOLERANCE) assertTrue(error.abs().sum().value() < TOLERANCE)
} }
} }
internal fun testingBatchedSymEig(device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingBatchedSymEig(device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
val tensor = randNormal(shape = intArrayOf(5,5), device = device) val tensor = randNormal(shape = intArrayOf(5,5), device = device)
val tensorSigma = tensor + tensor.transpose(-2,-1) val tensorSigma = tensor + tensor.transpose(-2,-1)
val (tensorS, tensorV) = tensorSigma.symEig() val (tensorS, tensorV) = tensorSigma.symEig()
val error = tensorSigma - (tensorV dot (diagEmbed(tensorS) dot tensorV.transpose(-2,-1))) val error = tensorSigma - (tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2,-1)))
assertTrue(error.abs().sum().value() < TOLERANCE) assertTrue(error.abs().sum().value() < TOLERANCE)
} }
} }

View File

@ -6,7 +6,7 @@ import kotlin.test.*
internal val SEED = 987654 internal val SEED = 987654
internal val TOLERANCE = 1e-6 internal val TOLERANCE = 1e-6
internal fun testingSetSeed(device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingSetSeed(device: Device = Device.CPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
setSeed(SEED) setSeed(SEED)
val normal = randNormal(IntArray(0), device = device).value() val normal = randNormal(IntArray(0), device = device).value()