Tensor transformations first examples
This commit is contained in:
parent
39f3a87bbd
commit
ca3cca65ef
@ -73,20 +73,20 @@ extern "C"
|
|||||||
TorchTensorHandle times_long(long value, TorchTensorHandle other);
|
TorchTensorHandle times_long(long value, TorchTensorHandle other);
|
||||||
TorchTensorHandle times_int(int value, TorchTensorHandle other);
|
TorchTensorHandle times_int(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
void times_assign_double(double value, TorchTensorHandle other);
|
void times_double_assign(double value, TorchTensorHandle other);
|
||||||
void times_assign_float(float value, TorchTensorHandle other);
|
void times_float_assign(float value, TorchTensorHandle other);
|
||||||
void times_assign_long(long value, TorchTensorHandle other);
|
void times_long_assign(long value, TorchTensorHandle other);
|
||||||
void times_assign_int(int value, TorchTensorHandle other);
|
void times_int_assign(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
TorchTensorHandle plus_double(double value, TorchTensorHandle other);
|
TorchTensorHandle plus_double(double value, TorchTensorHandle other);
|
||||||
TorchTensorHandle plus_float(float value, TorchTensorHandle other);
|
TorchTensorHandle plus_float(float value, TorchTensorHandle other);
|
||||||
TorchTensorHandle plus_long(long value, TorchTensorHandle other);
|
TorchTensorHandle plus_long(long value, TorchTensorHandle other);
|
||||||
TorchTensorHandle plus_int(int value, TorchTensorHandle other);
|
TorchTensorHandle plus_int(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
void plus_assign_double(double value, TorchTensorHandle other);
|
void plus_double_assign(double value, TorchTensorHandle other);
|
||||||
void plus_assign_float(float value, TorchTensorHandle other);
|
void plus_float_assign(float value, TorchTensorHandle other);
|
||||||
void plus_assign_long(long value, TorchTensorHandle other);
|
void plus_long_assign(long value, TorchTensorHandle other);
|
||||||
void plus_assign_int(int value, TorchTensorHandle other);
|
void plus_int_assign(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
@ -96,10 +96,22 @@ extern "C"
|
|||||||
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
TorchTensorHandle unary_minus(TorchTensorHandle tensor);
|
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
|
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
|
||||||
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
|
void abs_tensor_assign(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
|
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
|
||||||
|
void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j);
|
||||||
|
|
||||||
|
TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
void exp_tensor_assign(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
void log_tensor_assign(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
void sum_tensor_assign(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
bool requires_grad(TorchTensorHandle tensor_handle);
|
bool requires_grad(TorchTensorHandle tensor_handle);
|
||||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
||||||
|
@ -219,19 +219,19 @@ TorchTensorHandle plus_int(int value, TorchTensorHandle other)
|
|||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::cast(other) + value);
|
return new torch::Tensor(ctorch::cast(other) + value);
|
||||||
}
|
}
|
||||||
void plus_assign_double(double value, TorchTensorHandle other)
|
void plus_double_assign(double value, TorchTensorHandle other)
|
||||||
{
|
{
|
||||||
ctorch::cast(other) += value;
|
ctorch::cast(other) += value;
|
||||||
}
|
}
|
||||||
void plus_assign_float(float value, TorchTensorHandle other)
|
void plus_float_assign(float value, TorchTensorHandle other)
|
||||||
{
|
{
|
||||||
ctorch::cast(other) += value;
|
ctorch::cast(other) += value;
|
||||||
}
|
}
|
||||||
void plus_assign_long(long value, TorchTensorHandle other)
|
void plus_long_assign(long value, TorchTensorHandle other)
|
||||||
{
|
{
|
||||||
ctorch::cast(other) += value;
|
ctorch::cast(other) += value;
|
||||||
}
|
}
|
||||||
void plus_assign_int(int value, TorchTensorHandle other)
|
void plus_int_assign(int value, TorchTensorHandle other)
|
||||||
{
|
{
|
||||||
ctorch::cast(other) += value;
|
ctorch::cast(other) += value;
|
||||||
}
|
}
|
||||||
@ -252,19 +252,19 @@ TorchTensorHandle times_int(int value, TorchTensorHandle other)
|
|||||||
{
|
{
|
||||||
return new torch::Tensor(value * ctorch::cast(other));
|
return new torch::Tensor(value * ctorch::cast(other));
|
||||||
}
|
}
|
||||||
void times_assign_double(double value, TorchTensorHandle other)
|
void times_double_assign(double value, TorchTensorHandle other)
|
||||||
{
|
{
|
||||||
ctorch::cast(other) *= value;
|
ctorch::cast(other) *= value;
|
||||||
}
|
}
|
||||||
void times_assign_float(float value, TorchTensorHandle other)
|
void times_float_assign(float value, TorchTensorHandle other)
|
||||||
{
|
{
|
||||||
ctorch::cast(other) *= value;
|
ctorch::cast(other) *= value;
|
||||||
}
|
}
|
||||||
void times_assign_long(long value, TorchTensorHandle other)
|
void times_long_assign(long value, TorchTensorHandle other)
|
||||||
{
|
{
|
||||||
ctorch::cast(other) *= value;
|
ctorch::cast(other) *= value;
|
||||||
}
|
}
|
||||||
void times_assign_int(int value, TorchTensorHandle other)
|
void times_int_assign(int value, TorchTensorHandle other)
|
||||||
{
|
{
|
||||||
ctorch::cast(other) *= value;
|
ctorch::cast(other) *= value;
|
||||||
}
|
}
|
||||||
@ -301,21 +301,54 @@ void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
|||||||
{
|
{
|
||||||
ctorch::cast(lhs) /= ctorch::cast(rhs);
|
ctorch::cast(lhs) /= ctorch::cast(rhs);
|
||||||
}
|
}
|
||||||
TorchTensorHandle unary_minus(TorchTensorHandle tensor)
|
TorchTensorHandle unary_minus(TorchTensorHandle tensor_handle)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(-ctorch::cast(tensor));
|
return new torch::Tensor(-ctorch::cast(tensor_handle));
|
||||||
}
|
}
|
||||||
|
|
||||||
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
|
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::cast(tensor_handle).abs());
|
return new torch::Tensor(ctorch::cast(tensor_handle).abs());
|
||||||
}
|
}
|
||||||
|
void abs_tensor_assign(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).abs();
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
|
||||||
|
}
|
||||||
|
void transpose_tensor_assign(TorchTensorHandle tensor_handle, int i, int j)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).transpose(i, j);
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle exp_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).exp());
|
||||||
|
}
|
||||||
|
void exp_tensor_assign(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).exp();
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle log_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).log());
|
||||||
|
}
|
||||||
|
void log_tensor_assign(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).log();
|
||||||
|
}
|
||||||
|
|
||||||
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
|
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::cast(tensor_handle).sum());
|
return new torch::Tensor(ctorch::cast(tensor_handle).sum());
|
||||||
}
|
}
|
||||||
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
|
void sum_tensor_assign(TorchTensorHandle tensor_handle)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool requires_grad(TorchTensorHandle tensor_handle)
|
bool requires_grad(TorchTensorHandle tensor_handle)
|
||||||
|
@ -0,0 +1,9 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.*
|
||||||
|
|
||||||
|
|
||||||
|
internal class TestAutogradGPU {
|
||||||
|
@Test
|
||||||
|
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
|
||||||
|
}
|
@ -18,5 +18,7 @@ class TestTorchTensorAlgebraGPU {
|
|||||||
testingLinearStructure(device = TorchDevice.TorchCUDA(0))
|
testingLinearStructure(device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
|
fun testTensorTransformations() =
|
||||||
|
testingTensorTransformations(device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,6 @@ public sealed class TorchTensor<T> constructor(
|
|||||||
private fun close(): Unit = dispose_tensor(tensorHandle)
|
private fun close(): Unit = dispose_tensor(tensorHandle)
|
||||||
|
|
||||||
protected abstract fun item(): T
|
protected abstract fun item(): T
|
||||||
internal abstract fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer): TorchTensor<T>
|
|
||||||
|
|
||||||
override val dimension: Int get() = get_dim(tensorHandle)
|
override val dimension: Int get() = get_dim(tensorHandle)
|
||||||
override val shape: IntArray
|
override val shape: IntArray
|
||||||
@ -50,18 +49,6 @@ public sealed class TorchTensor<T> constructor(
|
|||||||
return item()
|
return item()
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun copy(): TorchTensor<T> =
|
|
||||||
wrap(
|
|
||||||
outScope = scope,
|
|
||||||
outTensorHandle = copy_tensor(tensorHandle)!!
|
|
||||||
)
|
|
||||||
|
|
||||||
public fun copyToDevice(device: TorchDevice): TorchTensor<T> =
|
|
||||||
wrap(
|
|
||||||
outScope = scope,
|
|
||||||
outTensorHandle = copy_to_device(tensorHandle, device.toInt())!!
|
|
||||||
)
|
|
||||||
|
|
||||||
public var requiresGrad: Boolean
|
public var requiresGrad: Boolean
|
||||||
get() = requires_grad(tensorHandle)
|
get() = requires_grad(tensorHandle)
|
||||||
set(value) = requires_grad_(tensorHandle, value)
|
set(value) = requires_grad_(tensorHandle, value)
|
||||||
@ -76,9 +63,6 @@ public class TorchTensorReal internal constructor(
|
|||||||
tensorHandle: COpaquePointer
|
tensorHandle: COpaquePointer
|
||||||
) : TorchTensor<Double>(scope, tensorHandle) {
|
) : TorchTensor<Double>(scope, tensorHandle) {
|
||||||
override fun item(): Double = get_item_double(tensorHandle)
|
override fun item(): Double = get_item_double(tensorHandle)
|
||||||
override fun wrap(outScope: DeferScope, outTensorHandle: COpaquePointer
|
|
||||||
): TorchTensorReal = TorchTensorReal(scope = outScope, tensorHandle = outTensorHandle)
|
|
||||||
|
|
||||||
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
|
override fun get(index: IntArray): Double = get_double(tensorHandle, index.toCValues())
|
||||||
override fun set(index: IntArray, value: Double) {
|
override fun set(index: IntArray, value: Double) {
|
||||||
set_double(tensorHandle, index.toCValues(), value)
|
set_double(tensorHandle, index.toCValues(), value)
|
||||||
|
@ -3,93 +3,129 @@ package kscience.kmath.torch
|
|||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import kscience.kmath.ctorch.*
|
import kscience.kmath.ctorch.*
|
||||||
|
|
||||||
public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType> constructor(
|
public sealed class TorchTensorAlgebra<
|
||||||
|
T,
|
||||||
|
TVar : CPrimitiveVar,
|
||||||
|
PrimitiveArrayType,
|
||||||
|
TorchTensorType : TorchTensor<T>> constructor(
|
||||||
internal val scope: DeferScope
|
internal val scope: DeferScope
|
||||||
) {
|
) {
|
||||||
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensor<T>
|
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensorType
|
||||||
|
|
||||||
public abstract fun copyFromArray(
|
public abstract fun copyFromArray(
|
||||||
array: PrimitiveArrayType,
|
array: PrimitiveArrayType,
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
device: TorchDevice = TorchDevice.TorchCPU
|
device: TorchDevice = TorchDevice.TorchCPU
|
||||||
): TorchTensor<T>
|
): TorchTensorType
|
||||||
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
|
|
||||||
|
|
||||||
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensor<T>
|
public abstract fun TorchTensorType.copyToArray(): PrimitiveArrayType
|
||||||
public abstract fun TorchTensor<T>.getData(): CPointer<TVar>
|
|
||||||
|
|
||||||
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
|
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensorType
|
||||||
|
public abstract fun TorchTensorType.getData(): CPointer<TVar>
|
||||||
|
|
||||||
public abstract operator fun T.plus(other: TorchTensor<T>): TorchTensor<T>
|
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensorType
|
||||||
public abstract operator fun TorchTensor<T>.plus(value: T): TorchTensor<T>
|
|
||||||
public abstract operator fun TorchTensor<T>.plusAssign(value: T): Unit
|
|
||||||
public abstract operator fun T.minus(other: TorchTensor<T>): TorchTensor<T>
|
|
||||||
public abstract operator fun TorchTensor<T>.minus(value: T): TorchTensor<T>
|
|
||||||
public abstract operator fun TorchTensor<T>.minusAssign(value: T): Unit
|
|
||||||
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
|
|
||||||
public abstract operator fun TorchTensor<T>.times(value: T): TorchTensor<T>
|
|
||||||
public abstract operator fun TorchTensor<T>.timesAssign(value: T): Unit
|
|
||||||
|
|
||||||
public operator fun TorchTensor<T>.times(other: TorchTensor<T>): TorchTensor<T> =
|
public abstract operator fun T.plus(other: TorchTensorType): TorchTensorType
|
||||||
|
public abstract operator fun TorchTensorType.plus(value: T): TorchTensorType
|
||||||
|
public abstract operator fun TorchTensorType.plusAssign(value: T): Unit
|
||||||
|
public abstract operator fun T.minus(other: TorchTensorType): TorchTensorType
|
||||||
|
public abstract operator fun TorchTensorType.minus(value: T): TorchTensorType
|
||||||
|
public abstract operator fun TorchTensorType.minusAssign(value: T): Unit
|
||||||
|
public abstract operator fun T.times(other: TorchTensorType): TorchTensorType
|
||||||
|
public abstract operator fun TorchTensorType.times(value: T): TorchTensorType
|
||||||
|
public abstract operator fun TorchTensorType.timesAssign(value: T): Unit
|
||||||
|
|
||||||
|
public operator fun TorchTensorType.times(other: TorchTensorType): TorchTensorType =
|
||||||
wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
|
wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
public operator fun TorchTensor<T>.timesAssign(other: TorchTensor<T>): Unit {
|
|
||||||
|
public operator fun TorchTensorType.timesAssign(other: TorchTensorType): Unit {
|
||||||
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
public operator fun TorchTensor<T>.div(other: TorchTensor<T>): TorchTensor<T> =
|
|
||||||
|
public operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType =
|
||||||
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
|
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
public operator fun TorchTensor<T>.divAssign(other: TorchTensor<T>): Unit {
|
|
||||||
|
public operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
||||||
div_tensor_assign(this.tensorHandle, other.tensorHandle)
|
div_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
public infix fun TorchTensor<T>.dot(other: TorchTensor<T>): TorchTensor<T> =
|
public infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType =
|
||||||
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
|
||||||
public infix fun TorchTensor<T>.dotAssign(other: TorchTensor<T>): Unit {
|
public infix fun TorchTensorType.dotAssign(other: TorchTensorType): Unit {
|
||||||
matmul_assign(this.tensorHandle, other.tensorHandle)
|
matmul_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
public infix fun TorchTensor<T>.dotRightAssign(other: TorchTensor<T>): Unit {
|
public infix fun TorchTensorType.dotRightAssign(other: TorchTensorType): Unit {
|
||||||
matmul_right_assign(this.tensorHandle, other.tensorHandle)
|
matmul_right_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun TorchTensor<T>.plus(other: TorchTensor<T>): TorchTensor<T> =
|
public operator fun TorchTensorType.plus(other: TorchTensorType): TorchTensorType =
|
||||||
wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
|
||||||
public operator fun TorchTensor<T>.plusAssign(other: TorchTensor<T>): Unit {
|
public operator fun TorchTensorType.plusAssign(other: TorchTensorType): Unit {
|
||||||
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun TorchTensor<T>.minus(other: TorchTensor<T>): TorchTensor<T> =
|
public operator fun TorchTensorType.minus(other: TorchTensorType): TorchTensorType =
|
||||||
wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
|
||||||
public operator fun TorchTensor<T>.minusAssign(other: TorchTensor<T>): Unit {
|
public operator fun TorchTensorType.minusAssign(other: TorchTensorType): Unit {
|
||||||
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
public operator fun TorchTensor<T>.unaryMinus(): TorchTensor<T> =
|
|
||||||
|
public operator fun TorchTensorType.unaryMinus(): TorchTensorType =
|
||||||
wrap(unary_minus(this.tensorHandle)!!)
|
wrap(unary_minus(this.tensorHandle)!!)
|
||||||
|
|
||||||
public fun TorchTensor<T>.abs(): TorchTensor<T> = wrap(abs_tensor(tensorHandle)!!)
|
public fun TorchTensorType.abs(): TorchTensorType = wrap(abs_tensor(tensorHandle)!!)
|
||||||
|
public fun TorchTensorType.absAssign(): Unit {
|
||||||
|
abs_tensor_assign(tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
public fun TorchTensor<T>.sum(): TorchTensor<T> = wrap(sum_tensor(tensorHandle)!!)
|
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType =
|
||||||
|
wrap(transpose_tensor(tensorHandle, i, j)!!)
|
||||||
|
public fun TorchTensorType.transposeAssign(i: Int, j: Int): Unit {
|
||||||
|
transpose_tensor_assign(tensorHandle, i, j)
|
||||||
|
}
|
||||||
|
|
||||||
public fun TorchTensor<T>.transpose(i: Int, j: Int): TorchTensor<T> =
|
public fun TorchTensorType.sum(): TorchTensorType = wrap(sum_tensor(tensorHandle)!!)
|
||||||
wrap(transpose_tensor(tensorHandle, i , j)!!)
|
public fun TorchTensorType.sumAssign(): Unit {
|
||||||
|
sum_tensor_assign(tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
public infix fun TorchTensor<T>.grad(variable: TorchTensor<T>): TorchTensor<T> =
|
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||||
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||||
|
|
||||||
|
public fun TorchTensorType.copy(): TorchTensorType =
|
||||||
|
wrap(tensorHandle = copy_tensor(this.tensorHandle)!!)
|
||||||
|
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
|
||||||
|
wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class TorchTensorFieldAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType>(scope: DeferScope) :
|
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
||||||
TorchTensorAlgebra<T, TVar, PrimitiveArrayType>(scope) {
|
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
|
||||||
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
|
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) {
|
||||||
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
|
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
|
||||||
|
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
|
||||||
|
|
||||||
|
public fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
|
||||||
|
public fun TorchTensorType.expAssign(): Unit {
|
||||||
|
exp_tensor_assign(tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
|
||||||
|
public fun TorchTensorType.logAssign(): Unit {
|
||||||
|
log_tensor_assign(tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray>(scope) {
|
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray, TorchTensorReal>(scope) {
|
||||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
|
||||||
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
override fun TorchTensor<Double>.copyToArray(): DoubleArray =
|
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
||||||
this.elements().map { it.second }.toList().toDoubleArray()
|
this.elements().map { it.second }.toList().toDoubleArray()
|
||||||
|
|
||||||
override fun copyFromArray(
|
override fun copyFromArray(
|
||||||
@ -120,8 +156,8 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
|
|||||||
)!!
|
)!!
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun TorchTensor<Double>.getData(): CPointer<DoubleVar> {
|
override fun TorchTensorReal.getData(): CPointer<DoubleVar> {
|
||||||
require(this.device is TorchDevice.TorchCPU){
|
require(this.device is TorchDevice.TorchCPU) {
|
||||||
"This tensor is not on available on CPU"
|
"This tensor is not on available on CPU"
|
||||||
}
|
}
|
||||||
return get_data_double(this.tensorHandle)!!
|
return get_data_double(this.tensorHandle)!!
|
||||||
@ -137,46 +173,46 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
|
|||||||
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
|
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
|
||||||
)
|
)
|
||||||
|
|
||||||
override operator fun Double.plus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
|
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = plus_double(this, other.tensorHandle)!!
|
tensorHandle = plus_double(this, other.tensorHandle)!!
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun TorchTensor<Double>.plus(value: Double): TorchTensorReal = TorchTensorReal(
|
override fun TorchTensorReal.plus(value: Double): TorchTensorReal = TorchTensorReal(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = plus_double(value, this.tensorHandle)!!
|
tensorHandle = plus_double(value, this.tensorHandle)!!
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun TorchTensor<Double>.plusAssign(value: Double): Unit {
|
override fun TorchTensorReal.plusAssign(value: Double): Unit {
|
||||||
plus_assign_double(value, this.tensorHandle)
|
plus_double_assign(value, this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Double.minus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
|
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = plus_double(-this, other.tensorHandle)!!
|
tensorHandle = plus_double(-this, other.tensorHandle)!!
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun TorchTensor<Double>.minus(value: Double): TorchTensorReal = TorchTensorReal(
|
override fun TorchTensorReal.minus(value: Double): TorchTensorReal = TorchTensorReal(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = plus_double(-value, this.tensorHandle)!!
|
tensorHandle = plus_double(-value, this.tensorHandle)!!
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun TorchTensor<Double>.minusAssign(value: Double): Unit {
|
override fun TorchTensorReal.minusAssign(value: Double): Unit {
|
||||||
plus_assign_double(-value, this.tensorHandle)
|
plus_double_assign(-value, this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Double.times(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
|
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = times_double(this, other.tensorHandle)!!
|
tensorHandle = times_double(this, other.tensorHandle)!!
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun TorchTensor<Double>.times(value: Double): TorchTensorReal = TorchTensorReal(
|
override fun TorchTensorReal.times(value: Double): TorchTensorReal = TorchTensorReal(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = times_double(value, this.tensorHandle)!!
|
tensorHandle = times_double(value, this.tensorHandle)!!
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun TorchTensor<Double>.timesAssign(value: Double): Unit {
|
override fun TorchTensorReal.timesAssign(value: Double): Unit {
|
||||||
times_assign_double(value, this.tensorHandle)
|
times_double_assign(value, this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -0,0 +1,28 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.*
|
||||||
|
|
||||||
|
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
|
setSeed(SEED)
|
||||||
|
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
tensorX.requiresGrad = true
|
||||||
|
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
|
||||||
|
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
|
||||||
|
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
|
||||||
|
val expressionAtX =
|
||||||
|
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
|
||||||
|
|
||||||
|
val gradientAtX = expressionAtX grad tensorX
|
||||||
|
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
|
||||||
|
|
||||||
|
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
|
||||||
|
assertTrue(error < TOLERANCE)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class TestAutograd {
|
||||||
|
@Test
|
||||||
|
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
||||||
|
}
|
@ -3,8 +3,7 @@ package kscience.kmath.torch
|
|||||||
import kscience.kmath.linear.RealMatrixContext
|
import kscience.kmath.linear.RealMatrixContext
|
||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
import kscience.kmath.structures.Matrix
|
import kscience.kmath.structures.Matrix
|
||||||
import kotlin.math.abs
|
import kotlin.math.*
|
||||||
import kotlin.math.exp
|
|
||||||
import kotlin.test.*
|
import kotlin.test.*
|
||||||
|
|
||||||
internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||||
@ -40,7 +39,7 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch
|
|||||||
lhsTensorCopy dotAssign rhsTensor
|
lhsTensorCopy dotAssign rhsTensor
|
||||||
lhsTensor dotRightAssign rhsTensorCopy
|
lhsTensor dotRightAssign rhsTensorCopy
|
||||||
|
|
||||||
var error: Double = 0.0
|
var error = 0.0
|
||||||
product.elements().forEach {
|
product.elements().forEach {
|
||||||
error += abs(expected[it.first] - it.second) +
|
error += abs(expected[it.first] - it.second) +
|
||||||
abs(expected[it.first] - lhsTensorCopy[it.first]) +
|
abs(expected[it.first] - lhsTensorCopy[it.first]) +
|
||||||
@ -85,27 +84,24 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
internal fun testingTensorTransformations(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
setSeed(SEED)
|
setSeed(SEED)
|
||||||
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
|
val tensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||||
tensorX.requiresGrad = true
|
val result = tensor.exp().log()
|
||||||
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
|
val assignResult = tensor.copy()
|
||||||
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
|
assignResult.transposeAssign(0,1)
|
||||||
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
assignResult.expAssign()
|
||||||
|
assignResult.logAssign()
|
||||||
val expressionAtX =
|
assignResult.transposeAssign(0,1)
|
||||||
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
|
val error = tensor - result
|
||||||
|
error.absAssign()
|
||||||
val gradientAtX = expressionAtX grad tensorX
|
error.sumAssign()
|
||||||
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
|
error += (tensor - assignResult).abs().sum()
|
||||||
|
assertTrue (error.value()< TOLERANCE)
|
||||||
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
|
|
||||||
assertTrue(error < TOLERANCE)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
internal class TestTorchTensorAlgebra {
|
internal class TestTorchTensorAlgebra {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -118,6 +114,6 @@ internal class TestTorchTensorAlgebra {
|
|||||||
fun testLinearStructure() = testingLinearStructure()
|
fun testLinearStructure() = testingLinearStructure()
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
fun testTensorTransformations() = testingTensorTransformations()
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user