More linear operations on tensors

This commit is contained in:
rgrit91 2021-01-12 21:33:07 +00:00
parent 524b1d80d1
commit 39f3a87bbd
5 changed files with 139 additions and 13 deletions

View File

@ -106,10 +106,9 @@ TorchTensorRealAlgebra {
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
val Q = X + X.transpose(0,1)
val mu = randNormal(shape = intArrayOf(dim), device = device)
val c = randNormal(shape = IntArray(0), device = device)
// expression to differentiate w.r.t. x
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + 25.3
// value of the gradient at x
val gradf = f grad x
}

View File

@ -78,10 +78,25 @@ extern "C"
void times_assign_long(long value, TorchTensorHandle other);
void times_assign_int(int value, TorchTensorHandle other);
TorchTensorHandle plus_double(double value, TorchTensorHandle other);
TorchTensorHandle plus_float(float value, TorchTensorHandle other);
TorchTensorHandle plus_long(long value, TorchTensorHandle other);
TorchTensorHandle plus_int(int value, TorchTensorHandle other);
void plus_assign_double(double value, TorchTensorHandle other);
void plus_assign_float(float value, TorchTensorHandle other);
void plus_assign_long(long value, TorchTensorHandle other);
void plus_assign_int(int value, TorchTensorHandle other);
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle unary_minus(TorchTensorHandle tensor);
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);

View File

@ -203,6 +203,39 @@ void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
}
TorchTensorHandle plus_double(double value, TorchTensorHandle other)
{
return new torch::Tensor(ctorch::cast(other) + value);
}
TorchTensorHandle plus_float(float value, TorchTensorHandle other)
{
return new torch::Tensor(ctorch::cast(other) + value);
}
TorchTensorHandle plus_long(long value, TorchTensorHandle other)
{
return new torch::Tensor(ctorch::cast(other) + value);
}
TorchTensorHandle plus_int(int value, TorchTensorHandle other)
{
return new torch::Tensor(ctorch::cast(other) + value);
}
void plus_assign_double(double value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_assign_float(float value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_assign_long(long value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
void plus_assign_int(int value, TorchTensorHandle other)
{
ctorch::cast(other) += value;
}
TorchTensorHandle times_double(double value, TorchTensorHandle other)
{
return new torch::Tensor(value * ctorch::cast(other));
@ -252,6 +285,26 @@ void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(lhs) -= ctorch::cast(rhs);
}
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
return new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
}
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(lhs) *= ctorch::cast(rhs);
}
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
return new torch::Tensor(ctorch::cast(lhs) / ctorch::cast(rhs));
}
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
ctorch::cast(lhs) /= ctorch::cast(rhs);
}
TorchTensorHandle unary_minus(TorchTensorHandle tensor)
{
return new torch::Tensor(-ctorch::cast(tensor));
}
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
{
return new torch::Tensor(ctorch::cast(tensor_handle).abs());

View File

@ -20,10 +20,27 @@ public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayTyp
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
public abstract operator fun T.plus(other: TorchTensor<T>): TorchTensor<T>
public abstract operator fun TorchTensor<T>.plus(value: T): TorchTensor<T>
public abstract operator fun TorchTensor<T>.plusAssign(value: T): Unit
public abstract operator fun T.minus(other: TorchTensor<T>): TorchTensor<T>
public abstract operator fun TorchTensor<T>.minus(value: T): TorchTensor<T>
public abstract operator fun TorchTensor<T>.minusAssign(value: T): Unit
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
public abstract operator fun TorchTensor<T>.times(value: T): TorchTensor<T>
public abstract operator fun TorchTensor<T>.timesAssign(value: T): Unit
public operator fun TorchTensor<T>.times(other: TorchTensor<T>): TorchTensor<T> =
wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.timesAssign(other: TorchTensor<T>): Unit {
times_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public operator fun TorchTensor<T>.div(other: TorchTensor<T>): TorchTensor<T> =
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
public operator fun TorchTensor<T>.divAssign(other: TorchTensor<T>): Unit {
div_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public infix fun TorchTensor<T>.dot(other: TorchTensor<T>): TorchTensor<T> =
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
@ -48,6 +65,8 @@ public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayTyp
public operator fun TorchTensor<T>.minusAssign(other: TorchTensor<T>): Unit {
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
}
public operator fun TorchTensor<T>.unaryMinus(): TorchTensor<T> =
wrap(unary_minus(this.tensorHandle)!!)
public fun TorchTensor<T>.abs(): TorchTensor<T> = wrap(abs_tensor(tensorHandle)!!)
@ -118,6 +137,34 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
)
override operator fun Double.plus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = plus_double(this, other.tensorHandle)!!
)
override fun TorchTensor<Double>.plus(value: Double): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = plus_double(value, this.tensorHandle)!!
)
override fun TorchTensor<Double>.plusAssign(value: Double): Unit {
plus_assign_double(value, this.tensorHandle)
}
override operator fun Double.minus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = plus_double(-this, other.tensorHandle)!!
)
override fun TorchTensor<Double>.minus(value: Double): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = plus_double(-value, this.tensorHandle)!!
)
override fun TorchTensor<Double>.minusAssign(value: Double): Unit {
plus_assign_double(-value, this.tensorHandle)
}
override operator fun Double.times(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = times_double(this, other.tensorHandle)!!

View File

@ -56,9 +56,14 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
val tensorA = full(value = -4.5, shape = shape, device = device)
val tensorB = full(value = 10.9, shape = shape, device = device)
val tensorC = full(value = 789.3, shape = shape, device = device)
val result = 15.8 * tensorA - 1.5 * tensorB + 0.02 * tensorC
val tensorD = full(value = -72.9, shape = shape, device = device)
val tensorE = full(value = 553.1, shape = shape, device = device)
val result = 15.8 * tensorA - 1.5 * tensorB * (-tensorD) + 0.02 * tensorC / tensorE - 39.4
val expected = copyFromArray(
array = (1..3).map { 15.8 * (-4.5) - 1.5 * 10.9 + 0.02 * 789.3 }.toDoubleArray(),
array = (1..3).map {
15.8 * (-4.5) - 1.5 * 10.9 * 72.9 + 0.02 * 789.3 / 553.1 - 39.4
}
.toDoubleArray(),
shape = shape,
device = device
)
@ -66,10 +71,13 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
val assignResult = full(value = 0.0, shape = shape, device = device)
tensorA *= 15.8
tensorB *= 1.5
tensorB *= -tensorD
tensorC *= 0.02
tensorC /= tensorE
assignResult += tensorA
assignResult -= tensorB
assignResult += tensorC
assignResult += -39.4
val error = (expected - result).abs().sum().value() +
(expected - assignResult).abs().sum().value()
@ -80,15 +88,19 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
TorchTensorRealAlgebra {
setSeed(SEED)
val x = randNormal(shape = intArrayOf(dim), device = device)
x.requiresGrad = true
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
val Q = X + X.transpose(0,1)
val mu = randNormal(shape = intArrayOf(dim), device = device)
val c = randNormal(shape = IntArray(0), device = device)
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
val gradf = f grad x
val error = (gradf - ((Q dot x) + mu)).abs().sum().value()
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
tensorX.requiresGrad = true
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
val expressionAtX =
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
val gradientAtX = expressionAtX grad tensorX
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
assertTrue(error < TOLERANCE)
}
}