forked from kscience/kmath

More linear operations on tensors

parent 524b1d80d1
commit 39f3a87bbd
@@ -106,10 +106,9 @@ TorchTensorRealAlgebra {
    val X = randNormal(shape = intArrayOf(dim,dim), device = device)
    val Q = X + X.transpose(0,1)
    val mu = randNormal(shape = intArrayOf(dim), device = device)
    val c = randNormal(shape = IntArray(0), device = device)

    // expression to differentiate w.r.t. x
    val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
    val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + 25.3
    // value of the gradient at x
    val gradf = f grad x
}
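For readability, the documentation snippet above can be assembled into a self-contained example. This is an editor's sketch, not part of the commit: the lines creating x and enabling gradients are taken from the autograd test further down in this diff, and the concrete dim and device values are illustrative assumptions.

// Editor's sketch, not part of this commit: the docs example with its elided context.
// `dim = 10` and the CPU device are illustrative choices only.
TorchTensorRealAlgebra {
    val dim = 10
    val device = TorchDevice.TorchCPU
    val x = randNormal(shape = intArrayOf(dim), device = device)
    x.requiresGrad = true
    val X = randNormal(shape = intArrayOf(dim, dim), device = device)
    val Q = X + X.transpose(0, 1)      // symmetric by construction
    val mu = randNormal(shape = intArrayOf(dim), device = device)
    // expression to differentiate w.r.t. x
    val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + 25.3
    // value of the gradient at x
    val gradf = f grad x
}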
@@ -78,10 +78,25 @@ extern "C"
void times_assign_long(long value, TorchTensorHandle other);
void times_assign_int(int value, TorchTensorHandle other);

TorchTensorHandle plus_double(double value, TorchTensorHandle other);
TorchTensorHandle plus_float(float value, TorchTensorHandle other);
TorchTensorHandle plus_long(long value, TorchTensorHandle other);
TorchTensorHandle plus_int(int value, TorchTensorHandle other);

void plus_assign_double(double value, TorchTensorHandle other);
void plus_assign_float(float value, TorchTensorHandle other);
void plus_assign_long(long value, TorchTensorHandle other);
void plus_assign_int(int value, TorchTensorHandle other);

TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
TorchTensorHandle unary_minus(TorchTensorHandle tensor);
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
@@ -203,6 +203,39 @@ void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
    ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
}

TorchTensorHandle plus_double(double value, TorchTensorHandle other)
{
    return new torch::Tensor(ctorch::cast(other) + value);
}
TorchTensorHandle plus_float(float value, TorchTensorHandle other)
{
    return new torch::Tensor(ctorch::cast(other) + value);
}
TorchTensorHandle plus_long(long value, TorchTensorHandle other)
{
    return new torch::Tensor(ctorch::cast(other) + value);
}
TorchTensorHandle plus_int(int value, TorchTensorHandle other)
{
    return new torch::Tensor(ctorch::cast(other) + value);
}
void plus_assign_double(double value, TorchTensorHandle other)
{
    ctorch::cast(other) += value;
}
void plus_assign_float(float value, TorchTensorHandle other)
{
    ctorch::cast(other) += value;
}
void plus_assign_long(long value, TorchTensorHandle other)
{
    ctorch::cast(other) += value;
}
void plus_assign_int(int value, TorchTensorHandle other)
{
    ctorch::cast(other) += value;
}

TorchTensorHandle times_double(double value, TorchTensorHandle other)
{
    return new torch::Tensor(value * ctorch::cast(other));
@@ -252,6 +285,26 @@ void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
    ctorch::cast(lhs) -= ctorch::cast(rhs);
}
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
    return new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
}
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
    ctorch::cast(lhs) *= ctorch::cast(rhs);
}
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
    return new torch::Tensor(ctorch::cast(lhs) / ctorch::cast(rhs));
}
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{
    ctorch::cast(lhs) /= ctorch::cast(rhs);
}
TorchTensorHandle unary_minus(TorchTensorHandle tensor)
{
    return new torch::Tensor(-ctorch::cast(tensor));
}
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
{
    return new torch::Tensor(ctorch::cast(tensor_handle).abs());
@@ -20,10 +20,27 @@ public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayTyp

    public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>

    public abstract operator fun T.plus(other: TorchTensor<T>): TorchTensor<T>
    public abstract operator fun TorchTensor<T>.plus(value: T): TorchTensor<T>
    public abstract operator fun TorchTensor<T>.plusAssign(value: T): Unit
    public abstract operator fun T.minus(other: TorchTensor<T>): TorchTensor<T>
    public abstract operator fun TorchTensor<T>.minus(value: T): TorchTensor<T>
    public abstract operator fun TorchTensor<T>.minusAssign(value: T): Unit
    public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
    public abstract operator fun TorchTensor<T>.times(value: T): TorchTensor<T>
    public abstract operator fun TorchTensor<T>.timesAssign(value: T): Unit

    public operator fun TorchTensor<T>.times(other: TorchTensor<T>): TorchTensor<T> =
        wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
    public operator fun TorchTensor<T>.timesAssign(other: TorchTensor<T>): Unit {
        times_tensor_assign(this.tensorHandle, other.tensorHandle)
    }
    public operator fun TorchTensor<T>.div(other: TorchTensor<T>): TorchTensor<T> =
        wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
    public operator fun TorchTensor<T>.divAssign(other: TorchTensor<T>): Unit {
        div_tensor_assign(this.tensorHandle, other.tensorHandle)
    }

    public infix fun TorchTensor<T>.dot(other: TorchTensor<T>): TorchTensor<T> =
        wrap(matmul(this.tensorHandle, other.tensorHandle)!!)

@@ -48,6 +65,8 @@ public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayTyp
    public operator fun TorchTensor<T>.minusAssign(other: TorchTensor<T>): Unit {
        minus_tensor_assign(this.tensorHandle, other.tensorHandle)
    }
    public operator fun TorchTensor<T>.unaryMinus(): TorchTensor<T> =
        wrap(unary_minus(this.tensorHandle)!!)

    public fun TorchTensor<T>.abs(): TorchTensor<T> = wrap(abs_tensor(tensorHandle)!!)
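To make the new surface concrete, here is a usage sketch of the operators declared above (scalar plus/minus/times and their in-place forms, element-wise times/div, and the infix dot). This is an editor's sketch, not part of the commit: it assumes the TorchTensorRealAlgebra scope defined in the next hunk, and the shapes and constants are arbitrary.

// Editor's sketch, not part of this commit: exercising the new operators.
TorchTensorRealAlgebra {
    val device = TorchDevice.TorchCPU
    val x = randNormal(shape = intArrayOf(3), device = device)
    val y = full(value = 2.0, shape = intArrayOf(3), device = device)

    val z = 0.5 * x + y * x - 1.0   // scalar and element-wise arithmetic
    z /= y                          // in-place element-wise division
    z += 3.0                        // in-place scalar addition

    val s = (z dot y).value()       // dot goes through matmul; value() as used in the tests below
}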
@@ -118,6 +137,34 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
        tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
    )

    override operator fun Double.plus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
        scope = scope,
        tensorHandle = plus_double(this, other.tensorHandle)!!
    )

    override fun TorchTensor<Double>.plus(value: Double): TorchTensorReal = TorchTensorReal(
        scope = scope,
        tensorHandle = plus_double(value, this.tensorHandle)!!
    )

    override fun TorchTensor<Double>.plusAssign(value: Double): Unit {
        plus_assign_double(value, this.tensorHandle)
    }

    override operator fun Double.minus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
        scope = scope,
        tensorHandle = plus_double(-this, other.tensorHandle)!!
    )

    override fun TorchTensor<Double>.minus(value: Double): TorchTensorReal = TorchTensorReal(
        scope = scope,
        tensorHandle = plus_double(-value, this.tensorHandle)!!
    )

    override fun TorchTensor<Double>.minusAssign(value: Double): Unit {
        plus_assign_double(-value, this.tensorHandle)
    }

    override operator fun Double.times(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
        scope = scope,
        tensorHandle = times_double(this, other.tensorHandle)!!
@@ -56,9 +56,14 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
    val tensorA = full(value = -4.5, shape = shape, device = device)
    val tensorB = full(value = 10.9, shape = shape, device = device)
    val tensorC = full(value = 789.3, shape = shape, device = device)
    val result = 15.8 * tensorA - 1.5 * tensorB + 0.02 * tensorC
    val tensorD = full(value = -72.9, shape = shape, device = device)
    val tensorE = full(value = 553.1, shape = shape, device = device)
    val result = 15.8 * tensorA - 1.5 * tensorB * (-tensorD) + 0.02 * tensorC / tensorE - 39.4
    val expected = copyFromArray(
        array = (1..3).map { 15.8 * (-4.5) - 1.5 * 10.9 + 0.02 * 789.3 }.toDoubleArray(),
        array = (1..3).map {
            15.8 * (-4.5) - 1.5 * 10.9 * 72.9 + 0.02 * 789.3 / 553.1 - 39.4
        }
            .toDoubleArray(),
        shape = shape,
        device = device
    )
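As a quick sanity check on the updated expectation (an editor's note, not part of the commit), every element of expected, and hence of result and of the in-place assignResult built in the next hunk, should evaluate to

$$
15.8\cdot(-4.5) \;-\; 1.5\cdot 10.9\cdot 72.9 \;+\; \frac{0.02\cdot 789.3}{553.1} \;-\; 39.4
\;=\; -71.1 \;-\; 1191.915 \;+\; 0.0285\ldots \;-\; 39.4 \;\approx\; -1302.39 .
$$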
@@ -66,10 +71,13 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
    val assignResult = full(value = 0.0, shape = shape, device = device)
    tensorA *= 15.8
    tensorB *= 1.5
    tensorB *= -tensorD
    tensorC *= 0.02
    tensorC /= tensorE
    assignResult += tensorA
    assignResult -= tensorB
    assignResult += tensorC
    assignResult += -39.4

    val error = (expected - result).abs().sum().value() +
            (expected - assignResult).abs().sum().value()
@@ -80,15 +88,19 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
    TorchTensorRealAlgebra {
        setSeed(SEED)
        val x = randNormal(shape = intArrayOf(dim), device = device)
        x.requiresGrad = true
        val X = randNormal(shape = intArrayOf(dim,dim), device = device)
        val Q = X + X.transpose(0,1)
        val mu = randNormal(shape = intArrayOf(dim), device = device)
        val c = randNormal(shape = IntArray(0), device = device)
        val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
        val gradf = f grad x
        val error = (gradf - ((Q dot x) + mu)).abs().sum().value()
        val tensorX = randNormal(shape = intArrayOf(dim), device = device)
        tensorX.requiresGrad = true
        val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
        val tensorSigma = randFeatures + randFeatures.transpose(0,1)
        val tensorMu = randNormal(shape = intArrayOf(dim), device = device)

        val expressionAtX =
            0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9

        val gradientAtX = expressionAtX grad tensorX
        val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu

        val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
        assertTrue(error < TOLERANCE)
    }
}
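For reference (an editor's note, not part of the commit), the gradient the assertion checks follows from the standard quadratic-form identity. Since tensorSigma = randFeatures plus its transpose is symmetric, writing it as $\Sigma$ gives

$$
f(x) = \tfrac{1}{2}\,x^{\top}\Sigma x + \mu^{\top}x + c
\quad\Longrightarrow\quad
\nabla f(x) = \tfrac{1}{2}\left(\Sigma + \Sigma^{\top}\right)x + \mu = \Sigma x + \mu,
$$

which is exactly expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu; the additive constant (25.9 here, 25.3 in the documentation example above) has no effect on the gradient.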