forked from karpathy/arxiv-sanity-preserver
-
Notifications
You must be signed in to change notification settings - Fork 0
/
buildsvm.py
68 lines (55 loc) · 1.99 KB
/
buildsvm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from sqlite3 import dbapi2 as sqlite3
import cPickle as pickle
import numpy as np
import json
import time
import dateutil.parser
import argparse
from random import shuffle
import re
import os
from sklearn import svm
import utils
DATABASE = 'as.db'
sqldb = sqlite3.connect(DATABASE)
sqldb.row_factory = sqlite3.Row # to return dicts rather than tuples
def query_db(query, args=(), one=False):
"""Queries the database and returns a list of dictionaries."""
cur = sqldb.execute(query, args)
rv = cur.fetchall()
return (rv[0] if rv else None) if one else rv
users = query_db('''select * from user''')
for u in users:
print u
print 'number of users: ', len(users)
def strip_version(idstr):
""" identity function if arxiv id has no version, otherwise strips it. """
parts = idstr.split('v')
return parts[0]
# fetch the tfidf matrix
meta = pickle.load(open("tfidf_meta.p", "rb"))
out = pickle.load(open("tfidf.p", "rb"))
X = out['X']
X = X.todense()
xtoi = { strip_version(x):i for x,i in meta['ptoi'].iteritems() }
user_sim = {}
for ii,u in enumerate(users):
print '%d/%d building an SVM for %s' % (ii, len(users), u['username'].encode('utf-8'))
uid = u['user_id']
lib = query_db('''select * from library where user_id = ?''', [uid])
pids = [x['paper_id'] for x in lib] # raw pids without version
posix = [xtoi[p] for p in pids if p in xtoi]
if not posix:
continue # empty library for this user maybe?
print pids
y = np.zeros(X.shape[0])
for ix in posix:
y[ix] = 1
#__init__(penalty='l2', loss='squared_hinge', dual=True, tol=0.0001, C=1.0, multi_class='ovr', fit_intercept=True, intercept_scaling=1, class_weight=None, verbose=0, random_state=None, max_iter=1000)[source]
clf = svm.LinearSVC(class_weight='auto', verbose=True, max_iter=10000, tol=1e-6, C=1)
clf.fit(X,y)
s = clf.decision_function(X)
sortix = np.argsort(-s)
user_sim[uid] = [strip_version(meta['pids'][ix]) for ix in list(sortix)]
print 'writing user_sim.p'
utils.safe_pickle_dump(user_sim, "user_sim.p")