-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstump.py
More file actions
33 lines (28 loc) · 929 Bytes
/
stump.py
File metadata and controls
33 lines (28 loc) · 929 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# This code is supporting material for the book
# Building Machine Learning Systems with Python
# by Willi Richert and Luis Pedro Coelho
# published by PACKT Publishing
#
# It is made available under the MIT License
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
# Search for the single-feature threshold rule that best separates the
# virginica samples from the other non-setosa samples of the iris dataset,
# then print the winning cut, its feature index and its accuracy.
data = load_iris()
features = data['data']
# Index the species names by the integer targets to get one name per sample.
labels = data['target_names'][data['target']]

# Discard every setosa sample; the same mask is applied to rows and labels
# so the two stay aligned.
setosa = (labels == 'setosa')
keep = ~setosa
features, labels = features[keep], labels[keep]
virginica = (labels == 'virginica')

# Exhaustive search: each observed value of each feature is tried as a cut,
# and a sample is predicted "virginica" when it lies strictly above the cut.
best_acc = -1.0
for fi, column in enumerate(features.T):
    # Sort a copy so the candidate cuts are visited in ascending order
    # without reordering the feature matrix itself.
    thresh = column.copy()
    thresh.sort()
    for t in thresh:
        pred = (column > t)
        # Accuracy is the fraction of samples whose prediction matches.
        acc = (pred == virginica).mean()
        # Strict comparison: on ties the earliest candidate is kept.
        if acc > best_acc:
            best_acc, best_fi, best_t = acc, fi, t

report = 'Best cut is {0} on feature {1}, which achieves accuracy of {2:.1%}.'
print(report.format(best_t, best_fi, best_acc))