Tensor transformations first examples

This commit is contained in:
rgrit91 2021-01-13 12:55:44 +00:00
parent 39f3a87bbd
commit ca3cca65ef
8 changed files with 213 additions and 113 deletions

View File

@ -73,20 +73,20 @@ extern "C"
TorchTensorHandle times_long(long value, TorchTensorHandle other); TorchTensorHandle times_long(long value, TorchTensorHandle other);
TorchTensorHandle times_int(int value, TorchTensorHandle other); TorchTensorHandle times_int(int value, TorchTensorHandle other);
void times_assign_double(double value, TorchTensorHandle other); void times_double_assign(double value, TorchTensorHandle other);
void times_assign_float(float value, TorchTensorHandle other); void times_float_assign(float value, TorchTensorHandle other);
void times_assign_long(long value, TorchTensorHandle other); void times_long_assign(long value, TorchTensorHandle other);
void times_assign_int(int value, TorchTensorHandle other); void times_int_assign(int value, TorchTensorHandle other);
TorchTensorHandle plus_double(double value, TorchTensorHandle other); TorchTensorHandle plus_double(double value, TorchTensorHandle other);
TorchTensorHandle plus_float(float value, TorchTensorHandle other); TorchTensorHandle plus_float(float value, TorchTensorHandle other);
TorchTensorHandle plus_long(long value, TorchTensorHandle other); TorchTensorHandle plus_long(long value, TorchTensorHandle other);
TorchTensorHandle plus_int(int value, TorchTensorHandle other); TorchTensorHandle plus_int(int value, TorchTensorHandle other);
void plus_assign_double(double value, TorchTensorHandle other); void plus_double_assign(double value, TorchTensorHandle other);
void plus_assign_float(float value, TorchTensorHandle other); void plus_float_assign(float value, TorchTensorHandle other);
void plus_assign_long(long value, TorchTensorHandle other); void plus_long_assign(long value, TorchTensorHandle other);
void plus_assign_int(int value, TorchTensorHandle other); void plus_int_assign(int value, TorchTensorHandle other);
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
@ -96,10 +96,22 @@ extern "C"
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs); TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs); void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle unary_minus(TorchTensorHandle tensor); TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle);
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle); TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle); void abs_tensor_assign(TorchTensorHandle tensor_handle);
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j); TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j);
TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle);
void exp_tensor_assign(TorchTensorHandle tensor_handle);
TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle);
void log_tensor_assign(TorchTensorHandle tensor_handle);
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
void sum_tensor_assign(TorchTensorHandle tensor_handle);
bool requires_grad(TorchTensorHandle tensor_handle); bool requires_grad(TorchTensorHandle tensor_handle);
void requires_grad_(TorchTensorHandle tensor_handle, bool status); void requires_grad_(TorchTensorHandle tensor_handle, bool status);

View File

@ -219,19 +219,19 @@ TorchTensorHandle plus_int(int value, TorchTensorHandle other)
{ {
return new torch::Tensor(ctorch::cast(other) + value); return new torch::Tensor(ctorch::cast(other) + value);
} }
void plus_assign_double(double value, TorchTensorHandle other) void plus_double_assign(double value, TorchTensorHandle other)
{ {
ctorch::cast(other) += value; ctorch::cast(other) += value;
} }
void plus_assign_float(float value, TorchTensorHandle other) void plus_float_assign(float value, TorchTensorHandle other)
{ {
ctorch::cast(other) += value; ctorch::cast(other) += value;
} }
void plus_assign_long(long value, TorchTensorHandle other) void plus_long_assign(long value, TorchTensorHandle other)
{ {
ctorch::cast(other) += value; ctorch::cast(other) += value;
} }
void plus_assign_int(int value, TorchTensorHandle other) void plus_int_assign(int value, TorchTensorHandle other)
{ {
ctorch::cast(other) += value; ctorch::cast(other) += value;
} }
@ -252,19 +252,19 @@ TorchTensorHandle times_int(int value, TorchTensorHandle other)
{ {
return new torch::Tensor(value * ctorch::cast(other)); return new torch::Tensor(value * ctorch::cast(other));
} }
void times_assign_double(double value, TorchTensorHandle other) void times_double_assign(double value, TorchTensorHandle other)
{ {
ctorch::cast(other) *= value; ctorch::cast(other) *= value;
} }
void times_assign_float(float value, TorchTensorHandle other) void times_float_assign(float value, TorchTensorHandle other)
{ {
ctorch::cast(other) *= value; ctorch::cast(other) *= value;
} }
void times_assign_long(long value, TorchTensorHandle other) void times_long_assign(long value, TorchTensorHandle other)
{ {
ctorch::cast(other) *= value; ctorch::cast(other) *= value;
} }
void times_assign_int(int value, TorchTensorHandle other) void times_int_assign(int value, TorchTensorHandle other)
{ {
ctorch::cast(other) *= value; ctorch::cast(other) *= value;
} }
@ -301,21 +301,54 @@ void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{ {
ctorch::cast(lhs) /= ctorch::cast(rhs); ctorch::cast(lhs) /= ctorch::cast(rhs);
} }
TorchTensorHandle unary_minus(TorchTensorHandle tensor) TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle)
{ {
return new torch::Tensor(-ctorch::cast(tensor)); return new torch::Tensor(-ctorch::cast(tensor_handle));
} }
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle) TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
{ {
return new torch::Tensor(ctorch::cast(tensor_handle).abs()); return new torch::Tensor(ctorch::cast(tensor_handle).abs());
} }
void abs_tensor_assign(TorchTensorHandle tensor_handle)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).abs();
}
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
{
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
}
void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).transpose(i, j);
}
TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).exp());
}
void exp_tensor_assign(TorchTensorHandle tensor_handle)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).exp();
}
TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).log());
}
void log_tensor_assign(TorchTensorHandle tensor_handle)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).log();
}
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle) TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
{ {
return new torch::Tensor(ctorch::cast(tensor_handle).sum()); return new torch::Tensor(ctorch::cast(tensor_handle).sum());
} }
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j) void sum_tensor_assign(TorchTensorHandle tensor_handle)
{ {
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j)); ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
} }
bool requires_grad(TorchTensorHandle tensor_handle) bool requires_grad(TorchTensorHandle tensor_handle)

View File

@ -0,0 +1,9 @@
package kscience.kmath.torch
import kotlin.test.*
internal class TestAutogradGPU {
@Test
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
}

View File

@ -18,5 +18,7 @@ class TestTorchTensorAlgebraGPU {
testingLinearStructure(device = TorchDevice.TorchCUDA(0)) testingLinearStructure(device = TorchDevice.TorchCUDA(0))
@Test @Test
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0)) fun testTensorTransformations() =
testingTensorTransformations(device = TorchDevice.TorchCUDA(0))
} }

View File

@ -16,7 +16,6 @@ public sealed class TorchTensor<T> constructor(
private fun close(): Unit = dispose_tensor(tensorHandle) private fun close(): Unit = dispose_tensor(tensorHandle)
protected abstract fun item(): T protected abstract fun item(): T
internal abstract fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer): TorchTensor<T>
override val dimension: Int get() = get_dim(tensorHandle) override val dimension: Int get() = get_dim(tensorHandle)
override val shape: IntArray override val shape: IntArray
@ -50,18 +49,6 @@ public sealed class TorchTensor<T> constructor(
return item() return item()
} }
public fun copy(): TorchTensor<T> =
wrap(
outScope = scope,
outTensorHandle = copy_tensor(tensorHandle)!!
)
public fun copyToDevice(device: TorchDevice): TorchTensor<T> =
wrap(
outScope = scope,
outTensorHandle = copy_to_device(tensorHandle, device.toInt())!!
)
public var requiresGrad: Boolean public var requiresGrad: Boolean
get() = requires_grad(tensorHandle) get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value) set(value) = requires_grad_(tensorHandle, value)
@ -76,9 +63,6 @@ public class TorchTensorReal internal constructor(
tensorHandle: COpaquePointer tensorHandle: COpaquePointer
) : TorchTensor<Double>(scope, tensorHandle) { ) : TorchTensor<Double>(scope, tensorHandle) {
override fun item(): Double = get_item_double(tensorHandle) override fun item(): Double = get_item_double(tensorHandle)
override fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer
): TorchTensorReal = TorchTensorReal(scope = outScope, tensorHandle = outTensorHandle)
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues()) override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Double) { override fun set(index: IntArray, value: Double) {
set_double(tensorHandle, index.toCValues(), value) set_double(tensorHandle, index.toCValues(), value)

View File

@ -3,93 +3,129 @@ package kscience.kmath.torch
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kscience.kmath.ctorch.* import kscience.kmath.ctorch.*
public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType> constructor( public sealed class TorchTensorAlgebra<
T,
TVar : CPrimitiveVar,
PrimitiveArrayType,
TorchTensorType : TorchTensor<T>> constructor(
internal val scope: DeferScope internal val scope: DeferScope
) { ) {
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensor<T> internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
public abstract fun copyFromArray( public abstract fun copyFromArray(
array: PrimitiveArrayType, array: PrimitiveArrayType,
shape: IntArray, shape: IntArray,
device: TorchDevice = TorchDevice.TorchCPU device: TorchDevice = TorchDevice.TorchCPU
): TorchTensor<T> ): TorchTensorType
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensor<T> public abstract fun TorchTensorType.copyToArray(): PrimitiveArrayType
public abstract fun TorchTensor<T>.getData(): CPointer<TVar>
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T> public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
public abstract fun TorchTensorType.getData(): CPointer<TVar>
public abstract operator fun T.plus(other: TorchTensor<T>): TorchTensor<T> public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensorType
public abstract operator fun TorchTensor<T>.plus(value: T): TorchTensor<T>
public abstract operator fun TorchTensor<T>.plusAssign(value: T): Unit
public abstract operator fun T.minus(other: TorchTensor<T>): TorchTensor<T>
public abstract operator fun TorchTensor<T>.minus(value: T): TorchTensor<T>
public abstract operator fun TorchTensor<T>.minusAssign(value: T): Unit
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
public abstract operator fun TorchTensor<T>.times(value: T): TorchTensor<T>
public abstract operator fun TorchTensor<T>.timesAssign(value: T): Unit
public operator fun TorchTensor<T>.times(other: TorchTensor<T>): TorchTensor<T> = public abstract operator fun T.plus(other: TorchTensorType): TorchTensorType
public abstract operator fun TorchTensorType.plus(value: T): TorchTensorType
public abstract operator fun TorchTensorType.plusAssign(value: T): Unit
public abstract operator fun T.minus(other: TorchTensorType): TorchTensorType
public abstract operator fun TorchTensorType.minus(value: T): TorchTensorType
public abstract operator fun TorchTensorType.minusAssign(value: T): Unit
public abstract operator fun T.times(other: TorchTensorType): TorchTensorType
public abstract operator fun TorchTensorType.times(value: T): TorchTensorType
public abstract operator fun TorchTensorType.timesAssign(value: T): Unit
public operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType =
wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!) wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.timesAssign(other: TorchTensor<T>): Unit {
public operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
times_tensor_assign(this.tensorHandle, other.tensorHandle) times_tensor_assign(this.tensorHandle, other.tensorHandle)
} }
public operator fun TorchTensor<T>.div(other: TorchTensor<T>): TorchTensor<T> =
public operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType =
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!) wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.divAssign(other: TorchTensor<T>): Unit {
public operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
div_tensor_assign(this.tensorHandle, other.tensorHandle) div_tensor_assign(this.tensorHandle, other.tensorHandle)
} }
public infix fun TorchTensor<T>.dot(other: TorchTensor<T>): TorchTensor<T> = public infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType =
wrap(matmul(this.tensorHandle, other.tensorHandle)!!) wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
public infix fun TorchTensor<T>.dotAssign(other: TorchTensor<T>): Unit { public infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
matmul_assign(this.tensorHandle, other.tensorHandle) matmul_assign(this.tensorHandle, other.tensorHandle)
} }
public infix fun TorchTensor<T>.dotRightAssign(other: TorchTensor<T>): Unit { public infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
matmul_right_assign(this.tensorHandle, other.tensorHandle) matmul_right_assign(this.tensorHandle, other.tensorHandle)
} }
public operator fun TorchTensor<T>.plus(other: TorchTensor<T>): TorchTensor<T> = public operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType =
wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!) wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.plusAssign(other: TorchTensor<T>): Unit { public operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
plus_tensor_assign(this.tensorHandle, other.tensorHandle) plus_tensor_assign(this.tensorHandle, other.tensorHandle)
} }
public operator fun TorchTensor<T>.minus(other: TorchTensor<T>): TorchTensor<T> = public operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType =
wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!) wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.minusAssign(other: TorchTensor<T>): Unit { public operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
minus_tensor_assign(this.tensorHandle, other.tensorHandle) minus_tensor_assign(this.tensorHandle, other.tensorHandle)
} }
public operator fun TorchTensor<T>.unaryMinus(): TorchTensor<T> =
public operator fun TorchTensorType.unaryMinus(): TorchTensorType =
wrap(unary_minus(this.tensorHandle)!!) wrap(unary_minus(this.tensorHandle)!!)
public fun TorchTensor<T>.abs(): TorchTensor<T> = wrap(abs_tensor(tensorHandle)!!) public fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
public fun TorchTensorType.absAssign(): Unit {
abs_tensor_assign(tensorHandle)
}
public fun TorchTensor<T>.sum(): TorchTensor<T> = wrap(sum_tensor(tensorHandle)!!) public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
wrap(transpose_tensor(tensorHandle, i, j)!!)
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
transpose_tensor_assign(tensorHandle, i, j)
}
public fun TorchTensor<T>.transpose(i: Int, j: Int): TorchTensor<T> = public fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
wrap(transpose_tensor(tensorHandle, i , j)!!) public fun TorchTensorType.sumAssign(): Unit {
sum_tensor_assign(tensorHandle)
}
public infix fun TorchTensor<T>.grad(variable: TorchTensor<T>): TorchTensor<T> = public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!) wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
public fun TorchTensorType.copy(): TorchTensorType =
wrap(tensorHandle = copy_tensor(this.tensorHandle)!!)
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!)
} }
public sealed class TorchTensorFieldAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType>(scope: DeferScope) : public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
TorchTensorAlgebra<T, TVar, PrimitiveArrayType>(scope) { PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T> TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) {
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T> public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
public fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
public fun TorchTensorType.expAssign(): Unit {
exp_tensor_assign(tensorHandle)
}
public fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
public fun TorchTensorType.logAssign(): Unit {
log_tensor_assign(tensorHandle)
}
} }
public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray>(scope) { public class TorchTensorRealAlgebra(scope: DeferScope) :
TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal = override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
TorchTensorReal(scope = scope, tensorHandle = tensorHandle) TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
override fun TorchTensor<Double>.copyToArray(): DoubleArray = override fun TorchTensorReal.copyToArray(): DoubleArray =
this.elements().map { it.second }.toList().toDoubleArray() this.elements().map { it.second }.toList().toDoubleArray()
override fun copyFromArray( override fun copyFromArray(
@ -120,8 +156,8 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
)!! )!!
) )
override fun TorchTensor<Double>.getData(): CPointer<DoubleVar> { override fun TorchTensorReal.getData(): CPointer<DoubleVar> {
require(this.device is TorchDevice.TorchCPU){ require(this.device is TorchDevice.TorchCPU) {
"This tensor is not on available on CPU" "This tensor is not on available on CPU"
} }
return get_data_double(this.tensorHandle)!! return get_data_double(this.tensorHandle)!!
@ -137,46 +173,46 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!! tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
) )
override operator fun Double.plus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal( override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
scope = scope, scope = scope,
tensorHandle = plus_double(this, other.tensorHandle)!! tensorHandle = plus_double(this, other.tensorHandle)!!
) )
override fun TorchTensor<Double>.plus(value: Double): TorchTensorReal = TorchTensorReal( override fun TorchTensorReal.plus(value: Double): TorchTensorReal = TorchTensorReal(
scope = scope, scope = scope,
tensorHandle = plus_double(value, this.tensorHandle)!! tensorHandle = plus_double(value, this.tensorHandle)!!
) )
override fun TorchTensor<Double>.plusAssign(value: Double): Unit { override fun TorchTensorReal.plusAssign(value: Double): Unit {
plus_assign_double(value, this.tensorHandle) plus_double_assign(value, this.tensorHandle)
} }
override operator fun Double.minus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal( override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
scope = scope, scope = scope,
tensorHandle = plus_double(-this, other.tensorHandle)!! tensorHandle = plus_double(-this, other.tensorHandle)!!
) )
override fun TorchTensor<Double>.minus(value: Double): TorchTensorReal = TorchTensorReal( override fun TorchTensorReal.minus(value: Double): TorchTensorReal = TorchTensorReal(
scope = scope, scope = scope,
tensorHandle = plus_double(-value, this.tensorHandle)!! tensorHandle = plus_double(-value, this.tensorHandle)!!
) )
override fun TorchTensor<Double>.minusAssign(value: Double): Unit { override fun TorchTensorReal.minusAssign(value: Double): Unit {
plus_assign_double(-value, this.tensorHandle) plus_double_assign(-value, this.tensorHandle)
} }
override operator fun Double.times(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal( override operator fun Double.times(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
scope = scope, scope = scope,
tensorHandle = times_double(this, other.tensorHandle)!! tensorHandle = times_double(this, other.tensorHandle)!!
) )
override fun TorchTensor<Double>.times(value: Double): TorchTensorReal = TorchTensorReal( override fun TorchTensorReal.times(value: Double): TorchTensorReal = TorchTensorReal(
scope = scope, scope = scope,
tensorHandle = times_double(value, this.tensorHandle)!! tensorHandle = times_double(value, this.tensorHandle)!!
) )
override fun TorchTensor<Double>.timesAssign(value: Double): Unit { override fun TorchTensorReal.timesAssign(value: Double): Unit {
times_assign_double(value, this.tensorHandle) times_double_assign(value, this.tensorHandle)
} }

View File

@ -0,0 +1,28 @@
package kscience.kmath.torch
import kotlin.test.*
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
TorchTensorRealAlgebra {
setSeed(SEED)
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
tensorX.requiresGrad = true
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
val expressionAtX =
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
val gradientAtX = expressionAtX grad tensorX
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
assertTrue(error < TOLERANCE)
}
}
internal class TestAutograd {
@Test
fun testAutoGrad() = testingAutoGrad(dim = 100)
}

View File

@ -3,8 +3,7 @@ package kscience.kmath.torch
import kscience.kmath.linear.RealMatrixContext import kscience.kmath.linear.RealMatrixContext
import kscience.kmath.operations.invoke import kscience.kmath.operations.invoke
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import kotlin.math.abs import kotlin.math.*
import kotlin.math.exp
import kotlin.test.* import kotlin.test.*
internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit {
@ -40,7 +39,7 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch
lhsTensorCopy dotAssign rhsTensor lhsTensorCopy dotAssign rhsTensor
lhsTensor dotRightAssign rhsTensorCopy lhsTensor dotRightAssign rhsTensorCopy
var error: Double = 0.0 var error = 0.0
product.elements().forEach { product.elements().forEach {
error += abs(expected[it.first] - it.second) + error += abs(expected[it.first] - it.second) +
abs(expected[it.first] - lhsTensorCopy[it.first]) + abs(expected[it.first] - lhsTensorCopy[it.first]) +
@ -85,27 +84,24 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
} }
} }
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingTensorTransformations(device: TorchDevice = TorchDevice.TorchCPU): Unit {
TorchTensorRealAlgebra { TorchTensorRealAlgebra {
setSeed(SEED) setSeed(SEED)
val tensorX = randNormal(shape = intArrayOf(dim), device = device) val tensor = randNormal(shape = intArrayOf(3, 3), device = device)
tensorX.requiresGrad = true val result = tensor.exp().log()
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device) val assignResult = tensor.copy()
val tensorSigma = randFeatures + randFeatures.transpose(0,1) assignResult.transposeAssign(0,1)
val tensorMu = randNormal(shape = intArrayOf(dim), device = device) assignResult.expAssign()
assignResult.logAssign()
val expressionAtX = assignResult.transposeAssign(0,1)
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9 val error = tensor - result
error.absAssign()
val gradientAtX = expressionAtX grad tensorX error.sumAssign()
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu error += (tensor - assignResult).abs().sum()
assertTrue (error.value()< TOLERANCE)
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
assertTrue(error < TOLERANCE)
} }
} }
internal class TestTorchTensorAlgebra { internal class TestTorchTensorAlgebra {
@Test @Test
@ -118,6 +114,6 @@ internal class TestTorchTensorAlgebra {
fun testLinearStructure() = testingLinearStructure() fun testLinearStructure() = testingLinearStructure()
@Test @Test
fun testAutoGrad() = testingAutoGrad(dim = 100) fun testTensorTransformations() = testingTensorTransformations()
} }