Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
obj-x86_64-linux-gnu
examples
Supervised
VersatileClassificationTutorial-NN.cpp
Go to the documentation of this file.
1
2
#include <exception>
#include <iostream>
#include <string>

#include <shark/Data/Dataset.h>
#include <shark/Data/Csv.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>

#include <shark/Models/Trees/KDTree.h>
#include <shark/Models/NearestNeighborClassifier.h>
#include <shark/Algorithms/NearestNeighbors/TreeNearestNeighbors.h>
9
10
11
using namespace
shark
;
12
13
int
main
()
14
{
15
// Load data, use 70% for training and 30% for testing.
16
// The path is hard coded; make sure to invoke the executable
17
// from a place where the data file can be found. It is located
18
// under [shark]/examples/Supervised/data.
19
ClassificationDataset
traindata, testdata;
20
importCSV
(traindata,
"data/quickstartData.csv"
,
LAST_COLUMN
,
' '
);
21
testdata =
splitAtElement
(traindata, 70 * traindata.
numberOfElements
() / 100);
22
23
unsigned
int
k = 3;
// number of neighbors
24
KDTree<RealVector>
tree(traindata.
inputs
());
25
TreeNearestNeighbors<RealVector, unsigned int>
algorithm(traindata, &tree);
26
NearestNeighborClassifier<RealVector>
model(&algorithm, k);
27
28
Data<unsigned int>
prediction = model(testdata.
inputs
());
29
30
ZeroOneLoss<unsigned int>
loss;
31
double
error_rate = loss(testdata.
labels
(), prediction);
32
33
std::cout <<
"model: "
<< model.
name
() << std::endl
34
<<
"test error rate: "
<< error_rate << std::endl;
35
}