Skip to content

Commit de74835

Browse files
committed
RFCT Simplify sampling code
1 parent 1421b34 commit de74835

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

ch08/load_ml100k.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -22,15 +22,12 @@ def load():
2222
reviews = sparse.csc_matrix((values, ij.T)).astype(float)
2323
return reviews.toarray()
2424

25-
2625
def get_train_test():
2726
import numpy as np
27+
import random
2828
reviews = load()
2929
U,M = np.where(reviews)
30-
test_idxs = set()
31-
while len(test_idxs) < len(U)//10:
32-
test_idxs.add(np.random.randint(0,len(U)-1))
33-
test_idxs = np.array(list(test_idxs))
30+
test_idxs = np.array(random.sample(range(len(U)), len(U)//10))
3431
train = reviews.copy()
3532
train[U[test_idxs], M[test_idxs]] = 0
3633

0 commit comments

Comments (0)