forked from kscience/kmath
basic algebra ops
This commit is contained in:
parent
2a97aca9b6
commit
623c96e4bb
@ -1,6 +1,6 @@
|
|||||||
/*
|
/*
|
||||||
* Copyright 2018-2021 KMath contributors.
|
* Copyright 2018-2021 KMath contributors.
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
* Use of this.cast() source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package space.kscience.kmath.noa
|
package space.kscience.kmath.noa
|
||||||
@ -20,18 +20,56 @@ constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
|
|||||||
protected abstract fun wrap(tensorHandle: TensorHandle): TensorType
|
protected abstract fun wrap(tensorHandle: TensorHandle): TensorType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A scalar tensor in this implementation must have empty shape
|
* A scalar tensor in this.cast() implementation must have empty shape
|
||||||
*/
|
*/
|
||||||
override fun Tensor<T>.valueOrNull(): T? =
|
override fun Tensor<T>.valueOrNull(): T? =
|
||||||
try { this.cast().item() } catch (e: NoaException) { null }
|
try {
|
||||||
|
this.cast().cast().item()
|
||||||
|
} catch (e: NoaException) {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.value(): T = this.cast().item()
|
override fun Tensor<T>.value(): T = this.cast().cast().item()
|
||||||
|
|
||||||
|
override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
|
||||||
|
return wrap(JNoa.timesTensor(this.cast().tensorHandle, other.cast().tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Tensor<T>.timesAssign(other: Tensor<T>): Unit {
|
||||||
|
JNoa.timesTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Tensor<T>.plus(other: Tensor<T>): TensorType {
|
||||||
|
return wrap(JNoa.plusTensor(this.cast().tensorHandle, other.cast().tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Tensor<T>.plusAssign(other: Tensor<T>): Unit {
|
||||||
|
JNoa.plusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Tensor<T>.minus(other: Tensor<T>): TensorType {
|
||||||
|
return wrap(JNoa.minusTensor(this.cast().tensorHandle, other.cast().tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Tensor<T>.minusAssign(other: Tensor<T>): Unit {
|
||||||
|
JNoa.minusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Tensor<T>.unaryMinus(): TensorType =
|
||||||
|
wrap(JNoa.unaryMinus(this.cast().tensorHandle))
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
|
public abstract class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
|
||||||
internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), LinearOpsTensorAlgebra<T>,
|
internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), LinearOpsTensorAlgebra<T>,
|
||||||
AnalyticTensorAlgebra<T> {
|
AnalyticTensorAlgebra<T> {
|
||||||
|
|
||||||
|
override operator fun Tensor<T>.div(other: Tensor<T>): TensorType {
|
||||||
|
return wrap(JNoa.divTensor(this.cast().tensorHandle, other.cast().tensorHandle))
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Tensor<T>.divAssign(other: Tensor<T>): Unit {
|
||||||
|
JNoa.divTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user