Tensor transformations first examples

This commit is contained in:
rgrit91 2021-01-13 12:55:44 +00:00
parent 39f3a87bbd
commit ca3cca65ef
8 changed files with 213 additions and 113 deletions

View File

@ -73,20 +73,20 @@ extern "C"
TorchTensorHandle times_long(long value, TorchTensorHandle other);
TorchTensorHandle times_int(int value, TorchTensorHandle other);
void times_assign_double(double value, TorchTensorHandle other);
void times_assign_float(float value, TorchTensorHandle other);
void times_assign_long(long value, TorchTensorHandle other);
void times_assign_int(int value, TorchTensorHandle other);
void times_double_assign(double value, TorchTensorHandle other);
void times_float_assign(float value, TorchTensorHandle other);
void times_long_assign(long value, TorchTensorHandle other);
void times_int_assign(int value, TorchTensorHandle other);
TorchTensorHandle plus_double(double value, TorchTensorHandle other);
TorchTensorHandle plus_float(float value, TorchTensorHandle other);
TorchTensorHandle plus_long(long value, TorchTensorHandle other);
TorchTensorHandle plus_int(int value, TorchTensorHandle other);
void plus_assign_double(double value, TorchTensorHandle other);
void plus_assign_float(float value, TorchTensorHandle other);
void plus_assign_long(long value, TorchTensorHandle other);
void plus_assign_int(int value, TorchTensorHandle other);
void plus_double_assign(double value, TorchTensorHandle other);
void plus_float_assign(float value, TorchTensorHandle other);
void plus_long_assign(long value, TorchTensorHandle other);
void plus_int_assign(int value, TorchTensorHandle other);
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
@ -96,10 +96,22 @@ extern "C"
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle unary_minus(TorchTensorHandle tensor);
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle);
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
void abs_tensor_assign(TorchTensorHandle tensor_handle);
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j);
TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle);
void exp_tensor_assign(TorchTensorHandle tensor_handle);
TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle);
void log_tensor_assign(TorchTensorHandle tensor_handle);
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
void sum_tensor_assign(TorchTensorHandle tensor_handle);
bool requires_grad(TorchTensorHandle tensor_handle);
void requires_grad_(TorchTensorHandle tensor_handle, bool status);

View File

@ -219,19 +219,19 @@ TorchTensorHandle plus_int(int value, TorchTensorHandle other)
{
return new torch::Tensor(ctorch::cast(other) + value);
}
void plus_assign_double(double value, TorchTensorHandle other)
void plus_double_assign(double value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_assign_float(float value, TorchTensorHandle other)
void plus_float_assign(float value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_assign_long(long value, TorchTensorHandle other)
void plus_long_assign(long value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_assign_int(int value, TorchTensorHandle other)
void plus_int_assign(int value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
@ -252,19 +252,19 @@ TorchTensorHandle times_int(int value, TorchTensorHandle other)
{
return new torch::Tensor(value * ctorch::cast(other));
}
void times_assign_double(double value, TorchTensorHandle other)
void times_double_assign(double value, TorchTensorHandle other)
{
ctorch::cast(other) *= value;
}
void times_assign_float(float value, TorchTensorHandle other)
void times_float_assign(float value, TorchTensorHandle other)
{
ctorch::cast(other) *= value;
}
void times_assign_long(long value, TorchTensorHandle other)
void times_long_assign(long value, TorchTensorHandle other)
{
ctorch::cast(other) *= value;
}
void times_assign_int(int value, TorchTensorHandle other)
void times_int_assign(int value, TorchTensorHandle other)
{
ctorch::cast(other) *= value;
}
@ -301,21 +301,54 @@ void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(lhs) /= ctorch::cast(rhs);
}
TorchTensorHandle unary_minus(TorchTensorHandle tensor)
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(-ctorch::cast(tensor));
return new torch::Tensor(-ctorch::cast(tensor_handle));
}
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).abs());
}
void abs_tensor_assign(TorchTensorHandle tensor_handle)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).abs();
}
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
{
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
}
void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).transpose(i, j);
}
TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).exp());
}
void exp_tensor_assign(TorchTensorHandle tensor_handle)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).exp();
}
TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).log());
}
void log_tensor_assign(TorchTensorHandle tensor_handle)
{
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).log();
}
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).sum());
}
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
void sum_tensor_assign(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
}
bool requires_grad(TorchTensorHandle tensor_handle)

View File

@ -0,0 +1,9 @@
package kscience.kmath.torch
import kotlin.test.*
internal class TestAutogradGPU {
@Test
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
}

View File

@ -18,5 +18,7 @@ class TestTorchTensorAlgebraGPU {
testingLinearStructure(device = TorchDevice.TorchCUDA(0))
@Test
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
fun testTensorTransformations() =
testingTensorTransformations(device = TorchDevice.TorchCUDA(0))
}

View File

@ -16,7 +16,6 @@ public sealed class TorchTensor<T> constructor(
private fun close(): Unit = dispose_tensor(tensorHandle)
protected abstract fun item(): T
internal abstract fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer): TorchTensor<T>
override val dimension: Int get() = get_dim(tensorHandle)
override val shape: IntArray
@ -50,18 +49,6 @@ public sealed class TorchTensor<T> constructor(
return item()
}
public fun copy(): TorchTensor<T> =
wrap(
outScope = scope,
outTensorHandle = copy_tensor(tensorHandle)!!
)
public fun copyToDevice(device: TorchDevice): TorchTensor<T> =
wrap(
outScope = scope,
outTensorHandle = copy_to_device(tensorHandle, device.toInt())!!
)
public var requiresGrad: Boolean
get() = requires_grad(tensorHandle)
set(value) = requires_grad_(tensorHandle, value)
@ -76,9 +63,6 @@ public class TorchTensorReal internal constructor(
tensorHandle: COpaquePointer
) : TorchTensor<Double>(scope, tensorHandle) {
override fun item(): Double = get_item_double(tensorHandle)
override fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer
): TorchTensorReal = TorchTensorReal(scope = outScope, tensorHandle = outTensorHandle)
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
override fun set(index: IntArray, value: Double) {
set_double(tensorHandle, index.toCValues(), value)

View File

@ -3,93 +3,129 @@ package kscience.kmath.torch
import kotlinx.cinterop.*
import kscience.kmath.ctorch.*
public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType> constructor(
public sealed class TorchTensorAlgebra<
T,
TVar : CPrimitiveVar,
PrimitiveArrayType,
TorchTensorType : TorchTensor<T>> constructor(
internal val scope: DeferScope
) {
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensor<T>
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
public abstract fun copyFromArray(
array: PrimitiveArrayType,
shape: IntArray,
device: TorchDevice = TorchDevice.TorchCPU
): TorchTensor<T>
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
): TorchTensorType
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensor<T>
public abstract fun TorchTensor<T>.getData(): CPointer<TVar>
public abstract fun TorchTensorType.copyToArray(): PrimitiveArrayType
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
public abstract fun TorchTensorType.getData(): CPointer<TVar>
public abstract operator fun T.plus(other: TorchTensor<T>): TorchTensor<T>
public abstract operator fun TorchTensor<T>.plus(value: T): TorchTensor<T>
public abstract operator fun TorchTensor<T>.plusAssign(value: T): Unit
public abstract operator fun T.minus(other: TorchTensor<T>): TorchTensor<T>
public abstract operator fun TorchTensor<T>.minus(value: T): TorchTensor<T>
public abstract operator fun TorchTensor<T>.minusAssign(value: T): Unit
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
public abstract operator fun TorchTensor<T>.times(value: T): TorchTensor<T>
public abstract operator fun TorchTensor<T>.timesAssign(value: T): Unit
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensorType
public operator fun TorchTensor<T>.times(other: TorchTensor<T>): TorchTensor<T> =
public abstract operator fun T.plus(other: TorchTensorType): TorchTensorType
public abstract operator fun TorchTensorType.plus(value: T): TorchTensorType
public abstract operator fun TorchTensorType.plusAssign(value: T): Unit
public abstract operator fun T.minus(other: TorchTensorType): TorchTensorType
public abstract operator fun TorchTensorType.minus(value: T): TorchTensorType
public abstract operator fun TorchTensorType.minusAssign(value: T): Unit
public abstract operator fun T.times(other: TorchTensorType): TorchTensorType
public abstract operator fun TorchTensorType.times(value: T): TorchTensorType
public abstract operator fun TorchTensorType.timesAssign(value: T): Unit
public operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType =
wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.timesAssign(other: TorchTensor<T>): Unit {
public operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
times_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public operator fun TorchTensor<T>.div(other: TorchTensor<T>): TorchTensor<T> =
public operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType =
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.divAssign(other: TorchTensor<T>): Unit {
public operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
div_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public infix fun TorchTensor<T>.dot(other: TorchTensor<T>): TorchTensor<T> =
public infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType =
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
public infix fun TorchTensor<T>.dotAssign(other: TorchTensor<T>): Unit {
public infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
matmul_assign(this.tensorHandle, other.tensorHandle)
}
public infix fun TorchTensor<T>.dotRightAssign(other: TorchTensor<T>): Unit {
public infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
matmul_right_assign(this.tensorHandle, other.tensorHandle)
}
public operator fun TorchTensor<T>.plus(other: TorchTensor<T>): TorchTensor<T> =
public operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType =
wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.plusAssign(other: TorchTensor<T>): Unit {
public operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public operator fun TorchTensor<T>.minus(other: TorchTensor<T>): TorchTensor<T> =
public operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType =
wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.minusAssign(other: TorchTensor<T>): Unit {
public operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public operator fun TorchTensor<T>.unaryMinus(): TorchTensor<T> =
public operator fun TorchTensorType.unaryMinus(): TorchTensorType =
wrap(unary_minus(this.tensorHandle)!!)
public fun TorchTensor<T>.abs(): TorchTensor<T> = wrap(abs_tensor(tensorHandle)!!)
public fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
public fun TorchTensorType.absAssign(): Unit {
abs_tensor_assign(tensorHandle)
}
public fun TorchTensor<T>.sum(): TorchTensor<T> = wrap(sum_tensor(tensorHandle)!!)
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
wrap(transpose_tensor(tensorHandle, i, j)!!)
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
transpose_tensor_assign(tensorHandle, i, j)
}
public fun TorchTensor<T>.transpose(i: Int, j: Int): TorchTensor<T> =
wrap(transpose_tensor(tensorHandle, i , j)!!)
public fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
public fun TorchTensorType.sumAssign(): Unit {
sum_tensor_assign(tensorHandle)
}
public infix fun TorchTensor<T>.grad(variable: TorchTensor<T>): TorchTensor<T> =
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
public fun TorchTensorType.copy(): TorchTensorType =
wrap(tensorHandle = copy_tensor(this.tensorHandle)!!)
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!)
}
public sealed class TorchTensorFieldAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType>(scope: DeferScope) :
TorchTensorAlgebra<T, TVar, PrimitiveArrayType>(scope) {
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) {
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
public fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
public fun TorchTensorType.expAssign(): Unit {
exp_tensor_assign(tensorHandle)
}
public fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
public fun TorchTensorType.logAssign(): Unit {
log_tensor_assign(tensorHandle)
}
}
public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray>(scope) {
public class TorchTensorRealAlgebra(scope: DeferScope) :
TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
override fun TorchTensor<Double>.copyToArray(): DoubleArray =
override fun TorchTensorReal.copyToArray(): DoubleArray =
this.elements().map { it.second }.toList().toDoubleArray()
override fun copyFromArray(
@ -120,8 +156,8 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
)!!
)
override fun TorchTensor<Double>.getData(): CPointer<DoubleVar> {
require(this.device is TorchDevice.TorchCPU){
override fun TorchTensorReal.getData(): CPointer<DoubleVar> {
require(this.device is TorchDevice.TorchCPU) {
"This tensor is not on available on CPU"
}
return get_data_double(this.tensorHandle)!!
@ -137,50 +173,50 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
)
override operator fun Double.plus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = plus_double(this, other.tensorHandle)!!
)
override fun TorchTensor<Double>.plus(value: Double): TorchTensorReal = TorchTensorReal(
override fun TorchTensorReal.plus(value: Double): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = plus_double(value, this.tensorHandle)!!
)
override fun TorchTensor<Double>.plusAssign(value: Double): Unit {
plus_assign_double(value, this.tensorHandle)
override fun TorchTensorReal.plusAssign(value: Double): Unit {
plus_double_assign(value, this.tensorHandle)
}
override operator fun Double.minus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = plus_double(-this, other.tensorHandle)!!
)
override fun TorchTensor<Double>.minus(value: Double): TorchTensorReal = TorchTensorReal(
override fun TorchTensorReal.minus(value: Double): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = plus_double(-value, this.tensorHandle)!!
)
override fun TorchTensor<Double>.minusAssign(value: Double): Unit {
plus_assign_double(-value, this.tensorHandle)
override fun TorchTensorReal.minusAssign(value: Double): Unit {
plus_double_assign(-value, this.tensorHandle)
}
override operator fun Double.times(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = times_double(this, other.tensorHandle)!!
)
override fun TorchTensor<Double>.times(value: Double): TorchTensorReal = TorchTensorReal(
override fun TorchTensorReal.times(value: Double): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = times_double(value, this.tensorHandle)!!
)
override fun TorchTensor<Double>.timesAssign(value: Double): Unit {
times_assign_double(value, this.tensorHandle)
override fun TorchTensorReal.timesAssign(value: Double): Unit {
times_double_assign(value, this.tensorHandle)
}
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = full_double(value, shape.toCValues(), shape.size, device.toInt())!!
)

View File

@ -0,0 +1,28 @@
package kscience.kmath.torch
import kotlin.test.*
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
TorchTensorRealAlgebra {
setSeed(SEED)
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
tensorX.requiresGrad = true
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
val expressionAtX =
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
val gradientAtX = expressionAtX grad tensorX
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
assertTrue(error < TOLERANCE)
}
}
internal class TestAutograd {
@Test
fun testAutoGrad() = testingAutoGrad(dim = 100)
}

View File

@ -3,8 +3,7 @@ package kscience.kmath.torch
import kscience.kmath.linear.RealMatrixContext
import kscience.kmath.operations.invoke
import kscience.kmath.structures.Matrix
import kotlin.math.abs
import kotlin.math.exp
import kotlin.math.*
import kotlin.test.*
internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit {
@ -40,7 +39,7 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch
lhsTensorCopy dotAssign rhsTensor
lhsTensor dotRightAssign rhsTensorCopy
var error: Double = 0.0
var error = 0.0
product.elements().forEach {
error += abs(expected[it.first] - it.second) +
abs(expected[it.first] - lhsTensorCopy[it.first]) +
@ -85,27 +84,24 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
}
}
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
internal fun testingTensorTransformations(device: TorchDevice = TorchDevice.TorchCPU): Unit {
TorchTensorRealAlgebra {
setSeed(SEED)
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
tensorX.requiresGrad = true
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
val expressionAtX =
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
val gradientAtX = expressionAtX grad tensorX
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
assertTrue(error < TOLERANCE)
val tensor = randNormal(shape = intArrayOf(3, 3), device = device)
val result = tensor.exp().log()
val assignResult = tensor.copy()
assignResult.transposeAssign(0,1)
assignResult.expAssign()
assignResult.logAssign()
assignResult.transposeAssign(0,1)
val error = tensor - result
error.absAssign()
error.sumAssign()
error += (tensor - assignResult).abs().sum()
assertTrue (error.value()< TOLERANCE)
}
}
internal class TestTorchTensorAlgebra {
@Test
@ -118,6 +114,6 @@ internal class TestTorchTensorAlgebra {
fun testLinearStructure() = testingLinearStructure()
@Test
fun testAutoGrad() = testingAutoGrad(dim = 100)
fun testTensorTransformations() = testingTensorTransformations()
}