More linear operations on tensors
This commit is contained in:
parent
524b1d80d1
commit
39f3a87bbd
@ -106,10 +106,9 @@ TorchTensorRealAlgebra {
|
|||||||
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
|
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
|
||||||
val Q = X + X.transpose(0,1)
|
val Q = X + X.transpose(0,1)
|
||||||
val mu = randNormal(shape = intArrayOf(dim), device = device)
|
val mu = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
val c = randNormal(shape = IntArray(0), device = device)
|
|
||||||
|
|
||||||
// expression to differentiate w.r.t. x
|
// expression to differentiate w.r.t. x
|
||||||
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
|
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + 25.3
|
||||||
// value of the gradient at x
|
// value of the gradient at x
|
||||||
val gradf = f grad x
|
val gradf = f grad x
|
||||||
}
|
}
|
||||||
|
@ -78,10 +78,25 @@ extern "C"
|
|||||||
void times_assign_long(long value, TorchTensorHandle other);
|
void times_assign_long(long value, TorchTensorHandle other);
|
||||||
void times_assign_int(int value, TorchTensorHandle other);
|
void times_assign_int(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
|
TorchTensorHandle plus_double(double value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle plus_float(float value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle plus_long(long value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle plus_int(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
|
void plus_assign_double(double value, TorchTensorHandle other);
|
||||||
|
void plus_assign_float(float value, TorchTensorHandle other);
|
||||||
|
void plus_assign_long(long value, TorchTensorHandle other);
|
||||||
|
void plus_assign_int(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
TorchTensorHandle unary_minus(TorchTensorHandle tensor);
|
||||||
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
|
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
|
||||||
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
|
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
|
||||||
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
|
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
|
||||||
|
@ -203,6 +203,39 @@ void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
|||||||
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle plus_double(double value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(other) + value);
|
||||||
|
}
|
||||||
|
TorchTensorHandle plus_float(float value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(other) + value);
|
||||||
|
}
|
||||||
|
TorchTensorHandle plus_long(long value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(other) + value);
|
||||||
|
}
|
||||||
|
TorchTensorHandle plus_int(int value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(other) + value);
|
||||||
|
}
|
||||||
|
void plus_assign_double(double value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
void plus_assign_float(float value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
void plus_assign_long(long value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
void plus_assign_int(int value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) += value;
|
||||||
|
}
|
||||||
|
|
||||||
TorchTensorHandle times_double(double value, TorchTensorHandle other)
|
TorchTensorHandle times_double(double value, TorchTensorHandle other)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(value * ctorch::cast(other));
|
return new torch::Tensor(value * ctorch::cast(other));
|
||||||
@ -252,6 +285,26 @@ void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
|||||||
{
|
{
|
||||||
ctorch::cast(lhs) -= ctorch::cast(rhs);
|
ctorch::cast(lhs) -= ctorch::cast(rhs);
|
||||||
}
|
}
|
||||||
|
TorchTensorHandle times_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(lhs) * ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
void times_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) *= ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
TorchTensorHandle div_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(lhs) / ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
void div_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) /= ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
TorchTensorHandle unary_minus(TorchTensorHandle tensor)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(-ctorch::cast(tensor));
|
||||||
|
}
|
||||||
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
|
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::cast(tensor_handle).abs());
|
return new torch::Tensor(ctorch::cast(tensor_handle).abs());
|
||||||
|
@ -20,10 +20,27 @@ public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayTyp
|
|||||||
|
|
||||||
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
|
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
|
||||||
|
|
||||||
|
public abstract operator fun T.plus(other: TorchTensor<T>): TorchTensor<T>
|
||||||
|
public abstract operator fun TorchTensor<T>.plus(value: T): TorchTensor<T>
|
||||||
|
public abstract operator fun TorchTensor<T>.plusAssign(value: T): Unit
|
||||||
|
public abstract operator fun T.minus(other: TorchTensor<T>): TorchTensor<T>
|
||||||
|
public abstract operator fun TorchTensor<T>.minus(value: T): TorchTensor<T>
|
||||||
|
public abstract operator fun TorchTensor<T>.minusAssign(value: T): Unit
|
||||||
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
|
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
|
||||||
public abstract operator fun TorchTensor<T>.times(value: T): TorchTensor<T>
|
public abstract operator fun TorchTensor<T>.times(value: T): TorchTensor<T>
|
||||||
public abstract operator fun TorchTensor<T>.timesAssign(value: T): Unit
|
public abstract operator fun TorchTensor<T>.timesAssign(value: T): Unit
|
||||||
|
|
||||||
|
public operator fun TorchTensor<T>.times(other: TorchTensor<T>): TorchTensor<T> =
|
||||||
|
wrap(times_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
public operator fun TorchTensor<T>.timesAssign(other: TorchTensor<T>): Unit {
|
||||||
|
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
public operator fun TorchTensor<T>.div(other: TorchTensor<T>): TorchTensor<T> =
|
||||||
|
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
public operator fun TorchTensor<T>.divAssign(other: TorchTensor<T>): Unit {
|
||||||
|
div_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
public infix fun TorchTensor<T>.dot(other: TorchTensor<T>): TorchTensor<T> =
|
public infix fun TorchTensor<T>.dot(other: TorchTensor<T>): TorchTensor<T> =
|
||||||
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
|
||||||
@ -48,6 +65,8 @@ public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayTyp
|
|||||||
public operator fun TorchTensor<T>.minusAssign(other: TorchTensor<T>): Unit {
|
public operator fun TorchTensor<T>.minusAssign(other: TorchTensor<T>): Unit {
|
||||||
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
public operator fun TorchTensor<T>.unaryMinus(): TorchTensor<T> =
|
||||||
|
wrap(unary_minus(this.tensorHandle)!!)
|
||||||
|
|
||||||
public fun TorchTensor<T>.abs(): TorchTensor<T> = wrap(abs_tensor(tensorHandle)!!)
|
public fun TorchTensor<T>.abs(): TorchTensor<T> = wrap(abs_tensor(tensorHandle)!!)
|
||||||
|
|
||||||
@ -118,6 +137,34 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
|
|||||||
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
|
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
|
||||||
)
|
)
|
||||||
|
|
||||||
|
override operator fun Double.plus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = plus_double(this, other.tensorHandle)!!
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun TorchTensor<Double>.plus(value: Double): TorchTensorReal = TorchTensorReal(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = plus_double(value, this.tensorHandle)!!
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun TorchTensor<Double>.plusAssign(value: Double): Unit {
|
||||||
|
plus_assign_double(value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Double.minus(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = plus_double(-this, other.tensorHandle)!!
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun TorchTensor<Double>.minus(value: Double): TorchTensorReal = TorchTensorReal(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = plus_double(-value, this.tensorHandle)!!
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun TorchTensor<Double>.minusAssign(value: Double): Unit {
|
||||||
|
plus_assign_double(-value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
override operator fun Double.times(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
|
override operator fun Double.times(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = times_double(this, other.tensorHandle)!!
|
tensorHandle = times_double(this, other.tensorHandle)!!
|
||||||
|
@ -56,9 +56,14 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
|
|||||||
val tensorA = full(value = -4.5, shape = shape, device = device)
|
val tensorA = full(value = -4.5, shape = shape, device = device)
|
||||||
val tensorB = full(value = 10.9, shape = shape, device = device)
|
val tensorB = full(value = 10.9, shape = shape, device = device)
|
||||||
val tensorC = full(value = 789.3, shape = shape, device = device)
|
val tensorC = full(value = 789.3, shape = shape, device = device)
|
||||||
val result = 15.8 * tensorA - 1.5 * tensorB + 0.02 * tensorC
|
val tensorD = full(value = -72.9, shape = shape, device = device)
|
||||||
|
val tensorE = full(value = 553.1, shape = shape, device = device)
|
||||||
|
val result = 15.8 * tensorA - 1.5 * tensorB * (-tensorD) + 0.02 * tensorC / tensorE - 39.4
|
||||||
val expected = copyFromArray(
|
val expected = copyFromArray(
|
||||||
array = (1..3).map { 15.8 * (-4.5) - 1.5 * 10.9 + 0.02 * 789.3 }.toDoubleArray(),
|
array = (1..3).map {
|
||||||
|
15.8 * (-4.5) - 1.5 * 10.9 * 72.9 + 0.02 * 789.3 / 553.1 - 39.4
|
||||||
|
}
|
||||||
|
.toDoubleArray(),
|
||||||
shape = shape,
|
shape = shape,
|
||||||
device = device
|
device = device
|
||||||
)
|
)
|
||||||
@ -66,10 +71,13 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
|
|||||||
val assignResult = full(value = 0.0, shape = shape, device = device)
|
val assignResult = full(value = 0.0, shape = shape, device = device)
|
||||||
tensorA *= 15.8
|
tensorA *= 15.8
|
||||||
tensorB *= 1.5
|
tensorB *= 1.5
|
||||||
|
tensorB *= -tensorD
|
||||||
tensorC *= 0.02
|
tensorC *= 0.02
|
||||||
|
tensorC /= tensorE
|
||||||
assignResult += tensorA
|
assignResult += tensorA
|
||||||
assignResult -= tensorB
|
assignResult -= tensorB
|
||||||
assignResult += tensorC
|
assignResult += tensorC
|
||||||
|
assignResult += -39.4
|
||||||
|
|
||||||
val error = (expected - result).abs().sum().value() +
|
val error = (expected - result).abs().sum().value() +
|
||||||
(expected - assignResult).abs().sum().value()
|
(expected - assignResult).abs().sum().value()
|
||||||
@ -80,15 +88,19 @@ internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU):
|
|||||||
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
setSeed(SEED)
|
setSeed(SEED)
|
||||||
val x = randNormal(shape = intArrayOf(dim), device = device)
|
val tensorX = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
x.requiresGrad = true
|
tensorX.requiresGrad = true
|
||||||
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
|
val randFeatures = randNormal(shape = intArrayOf(dim, dim), device = device)
|
||||||
val Q = X + X.transpose(0,1)
|
val tensorSigma = randFeatures + randFeatures.transpose(0,1)
|
||||||
val mu = randNormal(shape = intArrayOf(dim), device = device)
|
val tensorMu = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
val c = randNormal(shape = IntArray(0), device = device)
|
|
||||||
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
|
val expressionAtX =
|
||||||
val gradf = f grad x
|
0.5 * (tensorX dot (tensorSigma dot tensorX)) + (tensorMu dot tensorX) + 25.9
|
||||||
val error = (gradf - ((Q dot x) + mu)).abs().sum().value()
|
|
||||||
|
val gradientAtX = expressionAtX grad tensorX
|
||||||
|
val expectedGradientAtX = (tensorSigma dot tensorX) + tensorMu
|
||||||
|
|
||||||
|
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
|
||||||
assertTrue(error < TOLERANCE)
|
assertTrue(error < TOLERANCE)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user