|
18 | 18 |
|
19 | 19 |
|
20 | 20 | def train_plot(features, labels): |
21 | | - y0,y1 = features[:,2].min()*.9, features[:,2].max()*1.1 |
22 | | - x0,x1 = features[:,0].min()*.9, features[:,0].max()*1.1 |
23 | | - X = np.linspace(x0,x1,100) |
24 | | - Y = np.linspace(y0,y1,100) |
25 | | - X,Y = np.meshgrid(X,Y) |
26 | | - |
27 | | - model = learn_model(1, features[:,(0,2)], np.array(labels)) |
28 | | - C = apply_model(np.vstack([X.ravel(),Y.ravel()]).T, model).reshape(X.shape) |
| 21 | + y0, y1 = features[:, 2].min() * .9, features[:, 2].max() * 1.1 |
| 22 | + x0, x1 = features[:, 0].min() * .9, features[:, 0].max() * 1.1 |
| 23 | + X = np.linspace(x0, x1, 100) |
| 24 | + Y = np.linspace(y0, y1, 100) |
| 25 | + X, Y = np.meshgrid(X, Y) |
| 26 | + |
| 27 | + model = learn_model(1, features[:, (0, 2)], np.array(labels)) |
| 28 | + C = apply_model( |
| 29 | + np.vstack([X.ravel(), Y.ravel()]).T, model).reshape(X.shape) |
29 | 30 | if COLOUR_FIGURE: |
30 | | - cmap = ListedColormap([(1.,.6,.6),(.6,1.,.6),(.6,.6,1.)]) |
| 31 | + cmap = ListedColormap([(1., .6, .6), (.6, 1., .6), (.6, .6, 1.)]) |
31 | 32 | else: |
32 | | - cmap = ListedColormap([(1.,1.,1.),(.2,.2,.2),(.6,.6,.6)]) |
33 | | - plt.xlim(x0,x1) |
34 | | - plt.ylim(y0,y1) |
| 33 | + cmap = ListedColormap([(1., 1., 1.), (.2, .2, .2), (.6, .6, .6)]) |
| 34 | + plt.xlim(x0, x1) |
| 35 | + plt.ylim(y0, y1) |
35 | 36 | plt.xlabel(feature_names[0]) |
36 | 37 | plt.ylabel(feature_names[2]) |
37 | | - plt.pcolormesh(X,Y,C, cmap=cmap) |
| 38 | + plt.pcolormesh(X, Y, C, cmap=cmap) |
38 | 39 | if COLOUR_FIGURE: |
39 | | - cmap = ListedColormap([(1.,.0,.0),(.0,1.,.0),(.0,.0,1.)]) |
40 | | - plt.scatter(features[:,0], features[:,2], c=labels, cmap=cmap) |
| 40 | + cmap = ListedColormap([(1., .0, .0), (.0, 1., .0), (.0, .0, 1.)]) |
| 41 | + plt.scatter(features[:, 0], features[:, 2], c=labels, cmap=cmap) |
41 | 42 | else: |
42 | | - for lab,ma in zip(range(3), "Do^"): |
43 | | - plt.plot(features[labels == lab,0], features[labels == lab,2], ma, c=(1.,1.,1.)) |
| 43 | + for lab, ma in zip(range(3), "Do^"): |
| 44 | + plt.plot(features[labels == lab, 0], features[ |
| 45 | + labels == lab, 2], ma, c=(1., 1., 1.)) |
44 | 46 |
|
45 | 47 |
|
46 | | -features,labels = load_dataset('seeds') |
| 48 | +features, labels = load_dataset('seeds') |
47 | 49 | names = sorted(set(labels)) |
48 | 50 | labels = np.array([names.index(ell) for ell in labels]) |
49 | 51 |
|
|
0 commit comments