Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
obj-x86_64-linux-gnu
examples
Supervised
CARTTutorial.cpp
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief CART Tutorial Sample Code
6
*
7
* This file is part of the "CART" tutorial.
8
* It requires some toy sample data that comes with the library.
9
*
10
*
11
*
12
* \author K. N. Hansen
13
* \date 2012
14
*
15
*
16
* \par Copyright 1995-2015 Shark Development Team
17
*
18
* <BR><HR>
19
* This file is part of Shark.
20
* <http://image.diku.dk/shark/>
21
*
22
* Shark is free software: you can redistribute it and/or modify
23
* it under the terms of the GNU Lesser General Public License as published
24
* by the Free Software Foundation, either version 3 of the License, or
25
* (at your option) any later version.
26
*
27
* Shark is distributed in the hope that it will be useful,
28
* but WITHOUT ANY WARRANTY; without even the implied warranty of
29
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30
* GNU Lesser General Public License for more details.
31
*
32
* You should have received a copy of the GNU Lesser General Public License
33
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
34
*
35
*/
36
//===========================================================================
37
38
#include <
shark/Data/Csv.h
>
// importing CSV files
39
#include <
shark/Algorithms/Trainers/CARTTrainer.h
>
// the CART trainer
40
#include <
shark/ObjectiveFunctions/Loss/ZeroOneLoss.h
>
// 0/1 loss for evaluation
41
#include <iostream>
42
43
using namespace
std
;
44
using namespace
shark
;
45
46
47
int
main
() {
48
49
//*****************LOAD AND PREPARE DATA***********************//
50
51
// read data
52
ClassificationDataset
dataTrain;
53
importCSV
(dataTrain,
"data/C.csv"
,
LAST_COLUMN
,
' '
);
54
55
56
//Split the dataset into a training and a test dataset
57
ClassificationDataset
dataTest =
splitAtElement
(dataTrain,311);
58
59
cout <<
"Training set - number of data points: "
<< dataTrain.
numberOfElements
()
60
<<
" number of classes: "
<<
numberOfClasses
(dataTrain)
61
<<
" input dimension: "
<<
inputDimension
(dataTrain) << endl;
62
63
cout <<
"Test set - number of data points: "
<< dataTest.
numberOfElements
()
64
<<
" number of classes: "
<<
numberOfClasses
(dataTest)
65
<<
" input dimension: "
<<
inputDimension
(dataTest) << endl;
66
67
68
//Train the model
69
CARTTrainer
trainer;
70
CARTClassifier<RealVector>
model;
71
trainer.
train
(model, dataTrain);
72
73
// evaluate Random Forest classifier
74
ZeroOneLoss<unsigned int, RealVector>
loss;
75
Data<RealVector>
prediction = model(dataTrain.
inputs
());
76
cout <<
"CART on training set accuracy: "
<< 1. - loss.
eval
(dataTrain.
labels
(), prediction) << endl;
77
78
prediction = model(dataTest.
inputs
());
79
cout <<
"CART on test set accuracy: "
<< 1. - loss.
eval
(dataTest.
labels
(), prediction) << endl;
80
81
}