Change tensor algebra operations to prefix notations #479

Open
altavir wants to merge 1 commits from refactor/tensor-prefix into dev
5 changed files with 125 additions and 107 deletions
Showing only changes of commit bc8ffd02c2 - Show all commits

View File

@ -48,7 +48,7 @@ fun reluDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
class ReLU : Activation(::relu, ::reluDer) class ReLU : Activation(::relu, ::reluDer)
fun sigmoid(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra { fun sigmoid(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
1.0 / (1.0 + (-x).exp()) 1.0 / (1.0 + exp(-x))
} }
fun sigmoidDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra { fun sigmoidDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
@ -116,7 +116,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
onesForAnswers[intArrayOf(index, label)] = 1.0 onesForAnswers[intArrayOf(index, label)] = 1.0
} }
val softmaxValue = yPred.exp() / yPred.exp().sum(dim = 1, keepDim = true) val softmaxValue = exp(yPred) / exp(yPred).sum(dim = 1, keepDim = true)
(-onesForAnswers + softmaxValue) / (yPred.shape[0].toDouble()) (-onesForAnswers + softmaxValue) / (yPred.shape[0].toDouble())
} }

View File

@ -114,27 +114,29 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
override fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = override fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.mean(keepDim, dim).wrap() ndArray.mean(keepDim, dim).wrap()
override fun StructureND<T>.exp(): Nd4jArrayStructure<T> = Transforms.exp(ndArray).wrap()
override fun StructureND<T>.ln(): Nd4jArrayStructure<T> = Transforms.log(ndArray).wrap()
override fun StructureND<T>.sqrt(): Nd4jArrayStructure<T> = Transforms.sqrt(ndArray).wrap() override fun StructureND<T>.sqrt(): Nd4jArrayStructure<T> = Transforms.sqrt(ndArray).wrap()
override fun StructureND<T>.cos(): Nd4jArrayStructure<T> = Transforms.cos(ndArray).wrap()
override fun StructureND<T>.acos(): Nd4jArrayStructure<T> = Transforms.acos(ndArray).wrap()
override fun StructureND<T>.cosh(): Nd4jArrayStructure<T> = Transforms.cosh(ndArray).wrap()
override fun StructureND<T>.acosh(): Nd4jArrayStructure<T> = override fun exp(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.exp(arg.ndArray).wrap()
Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap() override fun ln(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.log(arg.ndArray).wrap()
override fun cos(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.cos(arg.ndArray).wrap()
override fun acos(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.acos(arg.ndArray).wrap()
override fun cosh(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.cosh(arg.ndArray).wrap()
override fun StructureND<T>.sin(): Nd4jArrayStructure<T> = Transforms.sin(ndArray).wrap() override fun acosh(arg: StructureND<T>): Nd4jArrayStructure<T> =
override fun StructureND<T>.asin(): Nd4jArrayStructure<T> = Transforms.asin(ndArray).wrap() Nd4j.getExecutioner().exec(ACosh(arg.ndArray, arg.ndArray.ulike())).wrap()
override fun StructureND<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
override fun StructureND<T>.asinh(): Nd4jArrayStructure<T> = override fun sin(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.sin(arg.ndArray).wrap()
Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap() override fun asin(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.asin(arg.ndArray).wrap()
override fun sinh(arg: StructureND<T>): Tensor<T> = Transforms.sinh(arg.ndArray).wrap()
override fun asinh(arg: StructureND<T>): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(ASinh(arg.ndArray, arg.ndArray.ulike())).wrap()
override fun tan(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.tan(arg.ndArray).wrap()
override fun atan(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.atan(arg.ndArray).wrap()
override fun tanh(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.tanh(arg.ndArray).wrap()
override fun atanh(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.atanh(arg.ndArray).wrap()
override fun StructureND<T>.tan(): Nd4jArrayStructure<T> = Transforms.tan(ndArray).wrap()
override fun StructureND<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap()
override fun StructureND<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap()
override fun StructureND<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap()
override fun StructureND<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap() override fun StructureND<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap()
override fun StructureND<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap() override fun StructureND<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap()
override fun StructureND<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = override fun StructureND<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
@ -178,7 +180,6 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
return array.wrap() return array.wrap()
} }
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override val StructureND<Double>.ndArray: INDArray override val StructureND<Double>.ndArray: INDArray
get() = when (this) { get() = when (this) {

View File

@ -72,78 +72,95 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> :
*/ */
public fun StructureND<T>.variance(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.variance(dim: Int, keepDim: Boolean): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.exp.html override fun sin(arg: StructureND<T>): Tensor<T>
public fun StructureND<T>.exp(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.log.html override fun cos(arg: StructureND<T>): Tensor<T>
public fun StructureND<T>.ln(): Tensor<T>
override fun asin(arg: StructureND<T>): Tensor<T>
override fun acos(arg: StructureND<T>): Tensor<T>
override fun tan(arg: StructureND<T>): Tensor<T>
override fun atan(arg: StructureND<T>): Tensor<T>
override fun exp(arg: StructureND<T>): Tensor<T>
override fun ln(arg: StructureND<T>): Tensor<T>
override fun sinh(arg: StructureND<T>): Tensor<T>
override fun cosh(arg: StructureND<T>): Tensor<T>
override fun tanh(arg: StructureND<T>): Tensor<T>
override fun asinh(arg: StructureND<T>): Tensor<T>
override fun acosh(arg: StructureND<T>): Tensor<T>
override fun atanh(arg: StructureND<T>): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html //For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html
public fun StructureND<T>.sqrt(): Tensor<T> public fun StructureND<T>.sqrt(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
public fun StructureND<T>.cos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos
public fun StructureND<T>.acos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh
public fun StructureND<T>.cosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh
public fun StructureND<T>.acosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin
public fun StructureND<T>.sin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin
public fun StructureND<T>.asin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh
public fun StructureND<T>.sinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh
public fun StructureND<T>.asinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan
public fun StructureND<T>.tan(): Tensor<T>
//https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan
public fun StructureND<T>.atan(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh
public fun StructureND<T>.tanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh
public fun StructureND<T>.atanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil //For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil
public fun StructureND<T>.ceil(): Tensor<T> public fun StructureND<T>.ceil(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor //For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
public fun StructureND<T>.floor(): Tensor<T> public fun StructureND<T>.floor(): Tensor<T>
override fun sin(arg: StructureND<T>): Tensor<T> = arg.sin() // //For information: https://pytorch.org/docs/stable/generated/torch.exp.html
// public fun StructureND<T>.exp(): Tensor<T> = exp(this)
override fun cos(arg: StructureND<T>): Tensor<T> = arg.cos() //
// //For information: https://pytorch.org/docs/stable/generated/torch.log.html
override fun asin(arg: StructureND<T>): Tensor<T> = arg.asin() // @JvmName("lnChain")
// public fun StructureND<T>.ln(): Tensor<T> = ln(this)
override fun acos(arg: StructureND<T>): Tensor<T> = arg.acos() //
// //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
override fun atan(arg: StructureND<T>): Tensor<T> = arg.atan() // @JvmName("cosChain")
// public fun StructureND<T>.cos(): Tensor<T> = cos(this)
override fun exp(arg: StructureND<T>): Tensor<T> = arg.exp() //
// //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos
override fun ln(arg: StructureND<T>): Tensor<T> = arg.ln() // @JvmName("acosChain")
// public fun StructureND<T>.acos(): Tensor<T> = acos(this)
override fun sinh(arg: StructureND<T>): Tensor<T> = arg.sinh() //
// //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh
override fun cosh(arg: StructureND<T>): Tensor<T> = arg.cosh() // @JvmName("coshChain")
// public fun StructureND<T>.cosh(): Tensor<T> = cosh(this)
override fun asinh(arg: StructureND<T>): Tensor<T> = arg.asinh() //
// //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh
override fun acosh(arg: StructureND<T>): Tensor<T> = arg.acosh() // @JvmName("acoshChain")
// public fun StructureND<T>.acosh(): Tensor<T> = acosh(this)
override fun atanh(arg: StructureND<T>): Tensor<T> = arg.atanh() //
// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin
// @JvmName("sinChain")
// public fun StructureND<T>.sin(): Tensor<T> = sin(this)
//
// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin
// @JvmName("asinChain")
// public fun StructureND<T>.asin(): Tensor<T> = asin(this)
//
// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh
// @JvmName("sinhChain")
// public fun StructureND<T>.sinh(): Tensor<T> = sinh(this)
//
// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh
// @JvmName("asinhChain")
// public fun StructureND<T>.asinh(): Tensor<T> = asinh(this)
//
// //For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan
// @JvmName("tanChain")
// public fun StructureND<T>.tan(): Tensor<T> = tan(this)
//
// //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan
// @JvmName("atanChain")
// public fun StructureND<T>.atan(): Tensor<T> = atan(this)
//
// //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh
// @JvmName("tanhChain")
// public fun StructureND<T>.tanh(): Tensor<T> = tanh(this)
//
// //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh
// @JvmName("atanhChain")
// public fun StructureND<T>.atanh(): Tensor<T> = atanh(this)
} }

View File

@ -685,35 +685,35 @@ public open class DoubleTensorAlgebra :
return resTensor return resTensor
} }
override fun StructureND<Double>.exp(): DoubleTensor = tensor.map { exp(it) } override fun exp(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { exp(it) }
override fun StructureND<Double>.ln(): DoubleTensor = tensor.map { ln(it) } override fun ln(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { ln(it) }
override fun StructureND<Double>.sqrt(): DoubleTensor = tensor.map { sqrt(it) } override fun StructureND<Double>.sqrt(): DoubleTensor = tensor.map { sqrt(it) }
override fun StructureND<Double>.cos(): DoubleTensor = tensor.map { cos(it) } override fun cos(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { cos(it) }
override fun StructureND<Double>.acos(): DoubleTensor = tensor.map { acos(it) } override fun acos(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { acos(it) }
override fun StructureND<Double>.cosh(): DoubleTensor = tensor.map { cosh(it) } override fun cosh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { cosh(it) }
override fun StructureND<Double>.acosh(): DoubleTensor = tensor.map { acosh(it) } override fun acosh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { acosh(it) }
override fun StructureND<Double>.sin(): DoubleTensor = tensor.map { sin(it) } override fun sin(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { sin(it) }
override fun StructureND<Double>.asin(): DoubleTensor = tensor.map { asin(it) } override fun asin(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { asin(it) }
override fun StructureND<Double>.sinh(): DoubleTensor = tensor.map { sinh(it) } override fun sinh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { sinh(it) }
override fun StructureND<Double>.asinh(): DoubleTensor = tensor.map { asinh(it) } override fun asinh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { asinh(it) }
override fun StructureND<Double>.tan(): DoubleTensor = tensor.map { tan(it) } override fun tan(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { tan(it) }
override fun StructureND<Double>.atan(): DoubleTensor = tensor.map { atan(it) } override fun atan(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { atan(it) }
override fun StructureND<Double>.tanh(): DoubleTensor = tensor.map { tanh(it) } override fun tanh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { tanh(it) }
override fun StructureND<Double>.atanh(): DoubleTensor = tensor.map { atanh(it) } override fun atanh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { atanh(it) }
override fun StructureND<Double>.ceil(): DoubleTensor = tensor.map { ceil(it) } override fun StructureND<Double>.ceil(): DoubleTensor = tensor.map { ceil(it) }

View File

@ -32,12 +32,12 @@ internal class TestDoubleAnalyticTensorAlgebra {
@Test @Test
fun testExp() = DoubleTensorAlgebra { fun testExp() = DoubleTensorAlgebra {
assertTrue { tensor.exp() eq expectedTensor(::exp) } assertTrue { exp(tensor) eq expectedTensor(::exp) }
} }
@Test @Test
fun testLog() = DoubleTensorAlgebra { fun testLog() = DoubleTensorAlgebra {
assertTrue { tensor.ln() eq expectedTensor(::ln) } assertTrue { ln(tensor) eq expectedTensor(::ln) }
} }
@Test @Test
@ -47,48 +47,48 @@ internal class TestDoubleAnalyticTensorAlgebra {
@Test @Test
fun testCos() = DoubleTensorAlgebra { fun testCos() = DoubleTensorAlgebra {
assertTrue { tensor.cos() eq expectedTensor(::cos) } assertTrue { cos(tensor) eq expectedTensor(::cos) }
} }
@Test @Test
fun testCosh() = DoubleTensorAlgebra { fun testCosh() = DoubleTensorAlgebra {
assertTrue { tensor.cosh() eq expectedTensor(::cosh) } assertTrue { cosh(tensor) eq expectedTensor(::cosh) }
} }
@Test @Test
fun testAcosh() = DoubleTensorAlgebra { fun testAcosh() = DoubleTensorAlgebra {
assertTrue { tensor.acosh() eq expectedTensor(::acosh) } assertTrue { acosh(tensor) eq expectedTensor(::acosh) }
} }
@Test @Test
fun testSin() = DoubleTensorAlgebra { fun testSin() = DoubleTensorAlgebra {
assertTrue { tensor.sin() eq expectedTensor(::sin) } assertTrue { sin(tensor) eq expectedTensor(::sin) }
} }
@Test @Test
fun testSinh() = DoubleTensorAlgebra { fun testSinh() = DoubleTensorAlgebra {
assertTrue { tensor.sinh() eq expectedTensor(::sinh) } assertTrue { sinh(tensor) eq expectedTensor(::sinh) }
} }
@Test @Test
fun testAsinh() = DoubleTensorAlgebra { fun testAsinh() = DoubleTensorAlgebra {
assertTrue { tensor.asinh() eq expectedTensor(::asinh) } assertTrue { asinh(tensor) eq expectedTensor(::asinh) }
} }
@Test @Test
fun testTan() = DoubleTensorAlgebra { fun testTan() = DoubleTensorAlgebra {
assertTrue { tensor.tan() eq expectedTensor(::tan) } assertTrue { tan(tensor) eq expectedTensor(::tan) }
} }
@Test @Test
fun testAtan() = DoubleTensorAlgebra { fun testAtan() = DoubleTensorAlgebra {
assertTrue { tensor.atan() eq expectedTensor(::atan) } assertTrue { atan(tensor) eq expectedTensor(::atan) }
} }
@Test @Test
fun testTanh() = DoubleTensorAlgebra { fun testTanh() = DoubleTensorAlgebra {
assertTrue { tensor.tanh() eq expectedTensor(::tanh) } assertTrue { tanh(tensor) eq expectedTensor(::tanh) }
} }
@Test @Test