Skip to content

Commit e885d6a

Browse files
committed
Fixing train/test bug in charts; updating for 2nd edition of the book.
1 parent 9e9da69 commit e885d6a

File tree

5 files changed

+68
-40
lines changed

5 files changed

+68
-40
lines changed

ch05/README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
Chapter 5 - Classification - Detecting Poor Answers
22
===================================================
33

4-
The book chapter is based on StackExchange's data blob from August 2012:
5-
[http://www.clearbits.net/get/2076-aug-2012.torrent](http://www.clearbits.net/get/2076-aug-2012.torrent)
4+
The book chapter is based on StackExchange's data blob from August 2012 for the first edition.
65

7-
After publishing the book, StackExchange stayed as awesome as it always has been and released an updated version:
8-
[https://archive.org/download/stackexchange/stackexchange_archive.torrent](https://archive.org/download/stackexchange/stackexchange_archive.torrent)
6+
After publishing the book, StackExchange released the May 2014 version at
7+
[https://archive.org/download/stackexchange/stackexchange_archive.torrent](https://archive.org/download/stackexchange/stackexchange_archive.torrent).
98

109
Note that using the latest version, you will get slightly different results.
1110

ch05/classify.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,12 @@
3030

3131
import nltk
3232

33-
# splitting questions into train (70%) and test(30%) and then take their
34-
# answers
35-
all_posts = list(meta.keys())
36-
all_questions = [q for q, v in meta.items() if v['ParentId'] == -1]
37-
all_answers = [q for q, v in meta.items() if v['ParentId'] != -1] # [:500]
33+
# The sorting below is only to ensure reproducible numbers. Further down
34+
# we will occasionally skip a fold when it contains instances of only
35+
# one label. The two lines below ensure that the behavior is exactly the
36+
# same for different runs.
37+
all_questions = sorted([q for q, v in meta.items() if v['ParentId'] == -1])
38+
all_answers = sorted([q for q, v in meta.items() if v['ParentId'] != -1])
3839

3940
feature_names = np.array((
4041
'NumTextTokens',
@@ -47,14 +48,6 @@
4748
'NumImages'
4849
))
4950

50-
# activate the following for reduced feature space
51-
"""
52-
feature_names = np.array((
53-
'NumTextTokens',
54-
'LinkCount',
55-
))
56-
"""
57-
5851

5952
def prepare_sent_features():
6053
for pid, text in fetch_posts(chosen, with_index=True):
@@ -80,17 +73,26 @@ def get_features(aid):
8073
return tuple(meta[aid][fn] for fn in feature_names)
8174

8275
qa_X = np.asarray([get_features(aid) for aid in all_answers])
83-
# Score > 0 tests => positive class is good answer
84-
# Score <= 0 tests => positive class is poor answer
85-
qa_Y = np.asarray([meta[aid]['Score'] > 0 for aid in all_answers])
76+
8677
classifying_answer = "good"
78+
#classifying_answer = "poor"
79+
80+
if classifying_answer == "good":
81+
# Score > 0 tests => positive class is good answer
82+
qa_Y = np.asarray([meta[aid]['Score'] > 0 for aid in all_answers])
83+
elif classifying_answer == "poor":
84+
# Score <= 0 tests => positive class is poor answer
85+
qa_Y = np.asarray([meta[aid]['Score'] <= 0 for aid in all_answers])
86+
else:
87+
raise Exception("classifying_answer='%s' is not supported" %
88+
classifying_answer)
8789

8890
for idx, feat in enumerate(feature_names):
8991
plot_feat_hist([(qa_X[:, idx], feat)])
90-
"""
91-
plot_feat_hist([(qa_X[:, idx], feature_names[idx]) for idx in [1,0]], 'feat_hist_two.png')
92-
plot_feat_hist([(qa_X[:, idx], feature_names[idx]) for idx in [3,4,5,6]], 'feat_hist_four.png')
93-
"""
92+
93+
#plot_feat_hist([(qa_X[:, idx], feature_names[idx]) for idx in [1,0]], 'feat_hist_two.png')
94+
#plot_feat_hist([(qa_X[:, idx], feature_names[idx]) for idx in [3,4,5,6]], 'feat_hist_four.png')
95+
9496
avg_scores_summary = []
9597

9698

@@ -115,10 +117,16 @@ def measure(clf_class, parameters, name, data_size=None, plot=False):
115117
pr_scores = []
116118
precisions, recalls, thresholds = [], [], []
117119

118-
for train, test in cv:
120+
for fold_idx, (train, test) in enumerate(cv):
119121
X_train, y_train = X[train], Y[train]
120122
X_test, y_test = X[test], Y[test]
121123

124+
only_one_class_in_train = len(set(y_train)) == 1
125+
only_one_class_in_test = len(set(y_test)) == 1
126+
if only_one_class_in_train or only_one_class_in_test:
127+
# this would pose problems later on
128+
continue
129+
122130
clf = clf_class(**parameters)
123131

124132
clf.fit(X_train, y_train)
@@ -145,12 +153,20 @@ def measure(clf_class, parameters, name, data_size=None, plot=False):
145153
precisions.append(precision)
146154
recalls.append(recall)
147155
thresholds.append(pr_thresholds)
156+
157+
# This threshold is determined at the end of the chapter 5,
158+
# where we find conditions such that precision is in the area of
159+
# about 80%. With it we trade off recall for precision.
160+
threshold_for_detecting_good_answers = 0.59
161+
162+
print("Clone #%i" % fold_idx)
148163
print(classification_report(y_test, proba[:, label_idx] >
149-
0.63, target_names=['not accepted', 'accepted']))
164+
threshold_for_detecting_good_answers, target_names=['not accepted', 'accepted']))
150165

151166
# get medium clone
152167
scores_to_sort = pr_scores # roc_scores
153168
medium = np.argsort(scores_to_sort)[len(scores_to_sort) / 2]
169+
print("Medium clone is #%i" % medium)
154170

155171
if plot:
156172
#plot_roc(roc_scores[medium], name, fprs[medium], tprs[medium])
@@ -178,6 +194,7 @@ def measure(clf_class, parameters, name, data_size=None, plot=False):
178194

179195

180196
def bias_variance_analysis(clf_class, parameters, name):
197+
#import ipdb;ipdb.set_trace()
181198
data_sizes = np.arange(60, 2000, 4)
182199

183200
train_errors = []
@@ -208,13 +225,16 @@ def k_complexity_analysis(clf_class, parameters):
208225

209226
plot_k_complexity(ks, train_errors, test_errors)
210227

211-
for k in [5]: # [5, 10, 40, 90]:
228+
for k in [5]:
229+
# for k in [5, 10, 40]:
230+
#measure(neighbors.KNeighborsClassifier, {'n_neighbors': k}, "%iNN" % k)
212231
bias_variance_analysis(neighbors.KNeighborsClassifier, {
213232
'n_neighbors': k}, "%iNN" % k)
214233
k_complexity_analysis(neighbors.KNeighborsClassifier, {'n_neighbors': k})
215234

216235
from sklearn.linear_model import LogisticRegression
217-
for C in [0.1]: # [0.01, 0.1, 1.0, 10.0]:
236+
for C in [0.1]:
237+
# for C in [0.01, 0.1, 1.0, 10.0]:
218238
name = "LogReg C=%.2f" % C
219239
bias_variance_analysis(LogisticRegression, {'penalty': 'l2', 'C': C}, name)
220240
measure(LogisticRegression, {'penalty': 'l2', 'C': C}, name, plot=True)

ch05/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import os
99

10-
DATA_DIR = "data" # put your posts-2011-12.xml into this directory
10+
DATA_DIR = "data" # put your posts-2012.xml into this directory
1111
CHART_DIR = "charts"
1212

1313
filtered = os.path.join(DATA_DIR, "filtered.tsv")

ch05/so_xml_to_tsv.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# to a question that has been asked in 2011 or 2012.
1111
#
1212

13+
import sys
1314
import os
1415
import re
1516
try:
@@ -24,9 +25,13 @@
2425

2526
from data import DATA_DIR
2627

27-
filename = os.path.join(DATA_DIR, "posts-2011-12.xml")
28+
#filename = os.path.join(DATA_DIR, "posts-2011-12.xml")
29+
filename = os.path.join(DATA_DIR, "posts-2012.xml")
30+
print("Reading from xml %s" % filename)
2831
filename_filtered = os.path.join(DATA_DIR, "filtered.tsv")
32+
print("Filtered: %s" % filename_filtered)
2933
filename_filtered_meta = os.path.join(DATA_DIR, "filtered-meta.json")
34+
print("Meta: %s" % filename_filtered_meta)
3035

3136
q_creation = {} # creation datetimes of questions
3237
q_accepted = {} # id of accepted answer
@@ -77,22 +82,26 @@ def filter_html(s):
7782
num_questions = 0
7883
num_answers = 0
7984

80-
from itertools import imap
85+
if sys.version_info.major < 3:
86+
# Python 2, map() returns a list, which will lead to out of memory errors.
87+
# The following import ensures that the script behaves like being executed
88+
# with Python 3.
89+
from itertools import imap as map
8190

8291

8392
def parsexml(filename):
8493
global num_questions, num_answers
8594

8695
counter = 0
8796

88-
it = imap(itemgetter(1),
89-
iter(etree.iterparse(filename, events=('start',))))
97+
it = map(itemgetter(1),
98+
iter(etree.iterparse(filename, events=('start',))))
9099

91100
root = next(it) # get posts element
92101

93102
for elem in it:
94103
if counter % 100000 == 0:
95-
print(counter)
104+
print("Processed %i <row/> elements" % counter)
96105

97106
counter += 1
98107

ch05/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ def plot_feat_hist(data_name_list, filename=None):
171171
assert filename is not None
172172

173173
pylab.figure(num=None, figsize=(8, 6))
174-
num_rows = 1 + (len(data_name_list) - 1) / 2
175-
num_cols = 1 if len(data_name_list) == 1 else 2
174+
num_rows = int(1 + (len(data_name_list) - 1) / 2)
175+
num_cols = int(1 if len(data_name_list) == 1 else 2)
176176
pylab.figure(figsize=(5 * num_cols, 4 * num_rows))
177177

178178
for i in range(num_rows):
@@ -191,7 +191,7 @@ def plot_feat_hist(data_name_list, filename=None):
191191
else:
192192
bins = max_val
193193
n, bins, patches = pylab.hist(
194-
x, bins=bins, normed=1, facecolor='blue', alpha=0.75)
194+
x, bins=bins, normed=1, alpha=0.75)
195195

196196
pylab.grid(True)
197197

@@ -209,7 +209,7 @@ def plot_bias_variance(data_sizes, train_errors, test_errors, name, title):
209209
pylab.title("Bias-Variance for '%s'" % name)
210210
pylab.plot(
211211
data_sizes, test_errors, "--", data_sizes, train_errors, "b-", lw=1)
212-
pylab.legend(["train error", "test error"], loc="upper right")
212+
pylab.legend(["test error", "train error"], loc="upper right")
213213
pylab.grid(True, linestyle='-', color='0.75')
214214
pylab.savefig(
215215
os.path.join(CHART_DIR, "bv_" + name.replace(" ", "_") + ".png"), bbox_inches="tight")
@@ -220,10 +220,10 @@ def plot_k_complexity(ks, train_errors, test_errors):
220220
pylab.ylim([0.0, 1.0])
221221
pylab.xlabel('k')
222222
pylab.ylabel('Error')
223-
pylab.title('Errors for for different values of k')
223+
pylab.title('Errors for for different values of $k$')
224224
pylab.plot(
225225
ks, test_errors, "--", ks, train_errors, "-", lw=1)
226-
pylab.legend(["train error", "test error"], loc="upper right")
226+
pylab.legend(["test error", "train error"], loc="upper right")
227227
pylab.grid(True, linestyle='-', color='0.75')
228228
pylab.savefig(
229229
os.path.join(CHART_DIR, "kcomplexity.png"), bbox_inches="tight")

0 commit comments

Comments (0)