forked from kscience/kmath
Naive classifier notebook
This commit is contained in:
parent
991ab907d8
commit
29977650f1
3
.gitignore
vendored
3
.gitignore
vendored
@ -3,10 +3,9 @@ build/
|
|||||||
out/
|
out/
|
||||||
|
|
||||||
.idea/
|
.idea/
|
||||||
|
|
||||||
|
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|
||||||
|
|
||||||
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
||||||
!gradle-wrapper.jar
|
!gradle-wrapper.jar
|
||||||
|
|
||||||
|
418
examples/notebooks/Naive classifier.ipynb
Normal file
418
examples/notebooks/Naive classifier.ipynb
Normal file
@ -0,0 +1,418 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"%use kmath(0.3.1-dev-5)\n",
|
||||||
|
"%use plotly(0.5.0)\n",
|
||||||
|
"@file:DependsOn(\"space.kscience:kmath-commons:0.3.1-dev-5\")"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "lQbSB87rNAn9lV6poArVWW",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"//Uncomment to work in Jupyter classic or DataLore\n",
|
||||||
|
"//Plotly.jupyter.notebook()"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "0UP158hfccGgjQtHz0wAi6",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# The model\n",
|
||||||
|
"\n",
|
||||||
|
"Defining the input data format, the statistic abstraction and the statistic implementation based on a weighted sum of elements."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"class XYValues(val xValues: DoubleArray, val yValues: DoubleArray) {\n",
|
||||||
|
" init {\n",
|
||||||
|
" require(xValues.size == yValues.size)\n",
|
||||||
|
" }\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"fun interface XYStatistic {\n",
|
||||||
|
" operator fun invoke(values: XYValues): Double\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"class ConvolutionalXYStatistic(val weights: DoubleArray) : XYStatistic {\n",
|
||||||
|
" override fun invoke(values: XYValues): Double {\n",
|
||||||
|
" require(weights.size == values.yValues.size)\n",
|
||||||
|
" val norm = values.yValues.sum()\n",
|
||||||
|
" return values.yValues.zip(weights) { value, weight -> value * weight }.sum()/norm\n",
|
||||||
|
" }\n",
|
||||||
|
"}"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "Zhgz1Ui91PWz0meJiQpHol",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# Generator\n",
|
||||||
|
"Generate sample data for parabolas and hyperbolas"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"fun generateParabolas(xValues: DoubleArray, a: Double, b: Double, c: Double): XYValues {\n",
|
||||||
|
" val yValues = xValues.map { x -> a * x * x + b * x + c }.toDoubleArray()\n",
|
||||||
|
" return XYValues(xValues, yValues)\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"fun generateHyperbols(xValues: DoubleArray, gamma: Double, x0: Double, y0: Double): XYValues {\n",
|
||||||
|
" val yValues = xValues.map { x -> y0 + gamma / (x - x0) }.toDoubleArray()\n",
|
||||||
|
" return XYValues(xValues, yValues)\n",
|
||||||
|
"}"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"val xValues = (1.0..10.0).step(1.0).toDoubleArray()\n",
|
||||||
|
"\n",
|
||||||
|
"val xy = generateHyperbols(xValues, 1.0, 0.0, 0.0)\n",
|
||||||
|
"\n",
|
||||||
|
"Plotly.plot {\n",
|
||||||
|
" scatter {\n",
|
||||||
|
" this.x.doubles = xValues\n",
|
||||||
|
" this.y.doubles = xy.yValues\n",
|
||||||
|
" }\n",
|
||||||
|
"}"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "ZE2atNvFzQsCvpAF8KK4ch",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"Create a default statistic with uniform weights"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"val statistic = ConvolutionalXYStatistic(DoubleArray(xValues.size){1.0})\n",
|
||||||
|
"statistic(xy)"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "EA5HaydTddRKYrtAUwd29h",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"import kotlin.random.Random\n",
|
||||||
|
"\n",
|
||||||
|
"val random = Random(1288)\n",
|
||||||
|
"\n",
|
||||||
|
"val parabolas = buildList{\n",
|
||||||
|
" repeat(500){\n",
|
||||||
|
" add(\n",
|
||||||
|
" generateParabolas(\n",
|
||||||
|
" xValues, \n",
|
||||||
|
" random.nextDouble(), \n",
|
||||||
|
" random.nextDouble(), \n",
|
||||||
|
" random.nextDouble()\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" }\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"val hyperbolas: List<XYValues> = buildList{\n",
|
||||||
|
" repeat(500){\n",
|
||||||
|
" add(\n",
|
||||||
|
" generateHyperbols(\n",
|
||||||
|
" xValues, \n",
|
||||||
|
" random.nextDouble()*10, \n",
|
||||||
|
" random.nextDouble(), \n",
|
||||||
|
" random.nextDouble()\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" }\n",
|
||||||
|
"}"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "t5t6IYmD7Q1ykeo9uijFfQ",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"Plotly.plot { \n",
|
||||||
|
" scatter { \n",
|
||||||
|
" x.doubles = xValues\n",
|
||||||
|
" y.doubles = parabolas[257].yValues\n",
|
||||||
|
" }\n",
|
||||||
|
" scatter { \n",
|
||||||
|
" x.doubles = xValues\n",
|
||||||
|
" y.doubles = hyperbolas[252].yValues\n",
|
||||||
|
" }\n",
|
||||||
|
" }"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "oXB8lmju7YVYjMRXITKnhO",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"Plotly.plot { \n",
|
||||||
|
" histogram { \n",
|
||||||
|
" name = \"parabolae\"\n",
|
||||||
|
" x.numbers = parabolas.map { statistic(it) }\n",
|
||||||
|
" }\n",
|
||||||
|
" histogram { \n",
|
||||||
|
" name = \"hyperbolae\"\n",
|
||||||
|
" x.numbers = hyperbolas.map { statistic(it) }\n",
|
||||||
|
" }\n",
|
||||||
|
"}"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "8EIIecUZrt2NNrOkhxG5P0",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"val lossFunction: (XYStatistic) -> Double = { statistic ->\n",
|
||||||
|
" - abs(parabolas.sumOf { statistic(it) } - hyperbolas.sumOf { statistic(it) })\n",
|
||||||
|
"}"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "h7UmglJW5zXkAfKHK40oIL",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"Using commons-math optimizer to optimize weights"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"import org.apache.commons.math3.optim.*\n",
|
||||||
|
"import org.apache.commons.math3.optim.nonlinear.scalar.*\n",
|
||||||
|
"import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.*\n",
|
||||||
|
"\n",
|
||||||
|
"val optimizer = SimplexOptimizer(1e-1, Double.MAX_VALUE)\n",
|
||||||
|
"\n",
|
||||||
|
"val result = optimizer.optimize(\n",
|
||||||
|
" ObjectiveFunction { point ->\n",
|
||||||
|
" lossFunction(ConvolutionalXYStatistic(point))\n",
|
||||||
|
" },\n",
|
||||||
|
" NelderMeadSimplex(xValues.size),\n",
|
||||||
|
" InitialGuess(DoubleArray(xValues.size){ 1.0 }),\n",
|
||||||
|
" GoalType.MINIMIZE,\n",
|
||||||
|
" MaxEval(100000)\n",
|
||||||
|
")"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "0EG3K4aCUciMlgGQKPvJ57",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"Print resulting weights of optimization"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"result.point"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "LelUlY0ZSlJEO9yC6SLk5B",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"Plotly.plot { \n",
|
||||||
|
" scatter { \n",
|
||||||
|
" y.doubles = result.point\n",
|
||||||
|
" }\n",
|
||||||
|
"}"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "AuFOq5t9KpOIkGrOLsVXNf",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# The resulting statistic distribution"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"val resultStatistic = ConvolutionalXYStatistic(result.point)\n",
|
||||||
|
"Plotly.plot { \n",
|
||||||
|
" histogram { \n",
|
||||||
|
" name = \"parabolae\"\n",
|
||||||
|
" x.numbers = parabolas.map { resultStatistic(it) }\n",
|
||||||
|
" }\n",
|
||||||
|
" histogram { \n",
|
||||||
|
" name = \"hyperbolae\"\n",
|
||||||
|
" x.numbers = hyperbolas.map { resultStatistic(it) }\n",
|
||||||
|
" }\n",
|
||||||
|
"}"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"metadata": {
|
||||||
|
"datalore": {
|
||||||
|
"node_id": "zvmq42DRdM5mZ3SpzviHwI",
|
||||||
|
"type": "CODE",
|
||||||
|
"hide_input_from_viewers": false,
|
||||||
|
"hide_output_from_viewers": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Kotlin",
|
||||||
|
"language": "kotlin",
|
||||||
|
"name": "kotlin"
|
||||||
|
},
|
||||||
|
"datalore": {
|
||||||
|
"version": 1,
|
||||||
|
"computation_mode": "JUPYTER",
|
||||||
|
"package_manager": "pip",
|
||||||
|
"base_environment": "default",
|
||||||
|
"packages": []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user