Code review.

This commit is contained in:
Alexander Nozik 2021-03-10 18:02:04 +03:00
parent 8ae8ebe871
commit 6a5ca2a115
9 changed files with 58 additions and 74 deletions

View File

@ -2,8 +2,8 @@
## [Unreleased] ## [Unreleased]
### Added ### Added
- Intrinsic value `two` for ExtendedField to work with hyperbolic functions
- ScaleOperations interface - ScaleOperations interface
- Field extends ScaleOperations
### Changed ### Changed
- Exponential operations merged with hyperbolic functions - Exponential operations merged with hyperbolic functions

View File

@ -34,6 +34,7 @@ readme {
ksciencePublish { ksciencePublish {
github("kmath") github("kmath")
space() space()
sonatype()
} }
apiValidation { apiValidation {

View File

@ -1,12 +1,11 @@
package space.kscience.kmath.ast package space.kscience.kmath.ast
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.RealField import space.kscience.kmath.operations.RealField
fun main() { fun main() {
val expr = RealField.mstInField { val expr = RealField.mstInField {
val x = bindSymbol("x") val x = bindSymbol("x")
x * 2.0 + 2.0 * one / x - 16.0 x * 2.0 + number(2.0) / x - 16.0
} }
repeat(10000000) { repeat(10000000) {

View File

@ -1,8 +1,9 @@
package space.kscience.kmath.linear package space.kscience.kmath.linear
import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.structures.RealBuffer import space.kscience.kmath.structures.RealBuffer
public object RealMatrixContext : MatrixContext<Double, BufferMatrix<Double>> { public object RealMatrixContext : MatrixContext<Double, BufferMatrix<Double>>, ScaleOperations<Matrix<Double>> {
public override fun produce( public override fun produce(
rows: Int, rows: Int,
@ -58,10 +59,13 @@ public object RealMatrixContext : MatrixContext<Double, BufferMatrix<Double>> {
} }
} }
override fun Matrix<Double>.times(value: Double): BufferMatrix<Double> { override fun scale(a: Matrix<Double>, value: Double): BufferMatrix<Double> {
val bufferMatrix = toBufferMatrix() val bufferMatrix = a.toBufferMatrix()
return produce(rowNum, colNum) { i, j -> bufferMatrix[i, j] * value } return produce(a.rowNum, a.colNum) { i, j -> bufferMatrix[i, j] * value }
} }
override fun Matrix<Double>.times(value: Double): BufferMatrix<Double> = scale(this, value)
// //
// override fun multiply(a: Matrix<Double>, k: Number): BufferMatrix<Double> { // override fun multiply(a: Matrix<Double>, k: Number): BufferMatrix<Double> {
// val aBufferMatrix = a.toBufferMatrix() // val aBufferMatrix = a.toBufferMatrix()

View File

@ -86,7 +86,7 @@ public interface NumericAlgebra<T> : Algebra<T> {
*/ */
public interface ScaleOperations<T> : Algebra<T> { public interface ScaleOperations<T> : Algebra<T> {
/** /**
* Scaling of element by scalar. * Scaling an element by a scalar.
* *
* @param a the multiplier. * @param a the multiplier.
* @param value the multiplicand. * @param value the multiplicand.

View File

@ -11,8 +11,8 @@ import space.kscience.kmath.structures.MutableBufferFactory
* Generic spline interpolator. Not recommended for performance critical places, use platform-specific and type specific ones. * Generic spline interpolator. Not recommended for performance critical places, use platform-specific and type specific ones.
* Based on https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/SplineInterpolator.java * Based on https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/SplineInterpolator.java
*/ */
public class SplineInterpolator<T : Comparable<T>, F : Field<T>>( public class SplineInterpolator<T : Comparable<T>>(
public override val algebra: F, public override val algebra: Field<T>,
public val bufferFactory: MutableBufferFactory<T>, public val bufferFactory: MutableBufferFactory<T>,
) : PolynomialInterpolator<T> { ) : PolynomialInterpolator<T> {
//TODO possibly optimize zeroed buffers //TODO possibly optimize zeroed buffers

View File

@ -80,21 +80,17 @@ public interface Nd4jArraySpace<T, S : Space<T>> : NDSpace<T, S>, Nd4jArrayAlgeb
public override val zero: Nd4jArrayStructure<T> public override val zero: Nd4jArrayStructure<T>
get() = Nd4j.zeros(*shape).wrap() get() = Nd4j.zeros(*shape).wrap()
public override fun add(a: NDStructure<T>, b: NDStructure<T>): Nd4jArrayStructure<T> { public override fun add(a: NDStructure<T>, b: NDStructure<T>): Nd4jArrayStructure<T> =
return a.ndArray.add(b.ndArray).wrap() a.ndArray.add(b.ndArray).wrap()
}
public override operator fun NDStructure<T>.minus(b: NDStructure<T>): Nd4jArrayStructure<T> { public override operator fun NDStructure<T>.minus(b: NDStructure<T>): Nd4jArrayStructure<T> =
return ndArray.sub(b.ndArray).wrap() ndArray.sub(b.ndArray).wrap()
}
public override operator fun NDStructure<T>.unaryMinus(): Nd4jArrayStructure<T> { public override operator fun NDStructure<T>.unaryMinus(): Nd4jArrayStructure<T> =
return ndArray.neg().wrap() ndArray.neg().wrap()
}
public fun multiply(a: NDStructure<T>, k: Number): Nd4jArrayStructure<T> { public fun multiply(a: NDStructure<T>, k: Number): Nd4jArrayStructure<T> =
return a.ndArray.mul(k).wrap() a.ndArray.mul(k).wrap()
}
} }
/** /**
@ -109,9 +105,8 @@ public interface Nd4jArrayRing<T, R : Ring<T>> : NDRing<T, R>, Nd4jArraySpace<T,
public override val one: Nd4jArrayStructure<T> public override val one: Nd4jArrayStructure<T>
get() = Nd4j.ones(*shape).wrap() get() = Nd4j.ones(*shape).wrap()
public override fun multiply(a: NDStructure<T>, b: NDStructure<T>): Nd4jArrayStructure<T> { public override fun multiply(a: NDStructure<T>, b: NDStructure<T>): Nd4jArrayStructure<T> =
return a.ndArray.mul(b.ndArray).wrap() a.ndArray.mul(b.ndArray).wrap()
}
// //
// public override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> { // public override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
// check(this) // check(this)
@ -250,33 +245,26 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure() public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
override fun scale(a: NDStructure<Float>, value: Double): NDStructure<Float> { override fun scale(a: NDStructure<Float>, value: Double): NDStructure<Float> =
return a.ndArray.mul(value).wrap() a.ndArray.mul(value).wrap()
}
public override operator fun NDStructure<Float>.div(arg: Float): Nd4jArrayStructure<Float> { public override operator fun NDStructure<Float>.div(arg: Float): Nd4jArrayStructure<Float> =
return ndArray.div(arg).wrap() ndArray.div(arg).wrap()
}
public override operator fun NDStructure<Float>.plus(arg: Float): Nd4jArrayStructure<Float> { public override operator fun NDStructure<Float>.plus(arg: Float): Nd4jArrayStructure<Float> =
return ndArray.add(arg).wrap() ndArray.add(arg).wrap()
}
public override operator fun NDStructure<Float>.minus(arg: Float): Nd4jArrayStructure<Float> { public override operator fun NDStructure<Float>.minus(arg: Float): Nd4jArrayStructure<Float> =
return ndArray.sub(arg).wrap() ndArray.sub(arg).wrap()
}
public override operator fun NDStructure<Float>.times(arg: Float): Nd4jArrayStructure<Float> { public override operator fun NDStructure<Float>.times(arg: Float): Nd4jArrayStructure<Float> =
return ndArray.mul(arg).wrap() ndArray.mul(arg).wrap()
}
public override operator fun Float.div(arg: NDStructure<Float>): Nd4jArrayStructure<Float> { public override operator fun Float.div(arg: NDStructure<Float>): Nd4jArrayStructure<Float> =
return arg.ndArray.rdiv(this).wrap() arg.ndArray.rdiv(this).wrap()
}
public override operator fun Float.minus(arg: NDStructure<Float>): Nd4jArrayStructure<Float> { public override operator fun Float.minus(arg: NDStructure<Float>): Nd4jArrayStructure<Float> =
return arg.ndArray.rsub(this).wrap() arg.ndArray.rsub(this).wrap()
}
} }
/** /**
@ -288,21 +276,17 @@ public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRi
public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure() public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure()
public override operator fun NDStructure<Int>.plus(arg: Int): Nd4jArrayStructure<Int> { public override operator fun NDStructure<Int>.plus(arg: Int): Nd4jArrayStructure<Int> =
return ndArray.add(arg).wrap() ndArray.add(arg).wrap()
}
public override operator fun NDStructure<Int>.minus(arg: Int): Nd4jArrayStructure<Int> { public override operator fun NDStructure<Int>.minus(arg: Int): Nd4jArrayStructure<Int> =
return ndArray.sub(arg).wrap() ndArray.sub(arg).wrap()
}
public override operator fun NDStructure<Int>.times(arg: Int): Nd4jArrayStructure<Int> { public override operator fun NDStructure<Int>.times(arg: Int): Nd4jArrayStructure<Int> =
return ndArray.mul(arg).wrap() ndArray.mul(arg).wrap()
}
public override operator fun Int.minus(arg: NDStructure<Int>): Nd4jArrayStructure<Int> { public override operator fun Int.minus(arg: NDStructure<Int>): Nd4jArrayStructure<Int> =
return arg.ndArray.rsub(this).wrap() arg.ndArray.rsub(this).wrap()
}
} }
/** /**
@ -314,19 +298,15 @@ public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayR
public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = checkShape(this).asLongStructure() public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = checkShape(this).asLongStructure()
public override operator fun NDStructure<Long>.plus(arg: Long): Nd4jArrayStructure<Long> { public override operator fun NDStructure<Long>.plus(arg: Long): Nd4jArrayStructure<Long> =
return ndArray.add(arg).wrap() ndArray.add(arg).wrap()
}
public override operator fun NDStructure<Long>.minus(arg: Long): Nd4jArrayStructure<Long> { public override operator fun NDStructure<Long>.minus(arg: Long): Nd4jArrayStructure<Long> =
return ndArray.sub(arg).wrap() ndArray.sub(arg).wrap()
}
public override operator fun NDStructure<Long>.times(arg: Long): Nd4jArrayStructure<Long> { public override operator fun NDStructure<Long>.times(arg: Long): Nd4jArrayStructure<Long> =
return ndArray.mul(arg).wrap() ndArray.mul(arg).wrap()
}
public override operator fun Long.minus(arg: NDStructure<Long>): Nd4jArrayStructure<Long> { public override operator fun Long.minus(arg: NDStructure<Long>): Nd4jArrayStructure<Long> =
return arg.ndArray.rsub(this).wrap() arg.ndArray.rsub(this).wrap()
}
} }

View File

@ -57,7 +57,7 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
} }
}.asStructure() }.asStructure()
override fun NDStructure<Double>.unaryMinus(): NDStructure<Double> = this * (-1) override fun NDStructure<Double>.unaryMinus(): NDStructure<Double> = -1 * this
public override fun NDStructure<Double>.map(transform: RealField.(Double) -> Double): ViktorNDStructure = public override fun NDStructure<Double>.map(transform: RealField.(Double) -> Double): ViktorNDStructure =
F64Array(*this@ViktorNDField.shape).apply { F64Array(*this@ViktorNDField.shape).apply {