We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1421b34 commit de74835Copy full SHA for de74835
ch08/load_ml100k.py
@@ -22,15 +22,12 @@ def load():
22
reviews = sparse.csc_matrix((values, ij.T)).astype(float)
23
return reviews.toarray()
24
25
-
26
def get_train_test():
27
import numpy as np
+ import random
28
reviews = load()
29
U,M = np.where(reviews)
30
- test_idxs = set()
31
- while len(test_idxs) < len(U)//10:
32
- test_idxs.add(np.random.randint(0,len(U)-1))
33
- test_idxs = np.array(list(test_idxs))
+ test_idxs = np.array(random.sample(range(len(U)), len(U)//10))
34
train = reviews.copy()
35
train[U[test_idxs], M[test_idxs]] = 0
36
0 commit comments