forked from benhamner/CauseEffectPairsChallenge
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_io.py
63 lines (51 loc) · 1.88 KB
/
data_io.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import csv
import json
import numpy as np
import os
import pandas as pd
import pickle
def get_paths():
    """Load the path settings from ``SETTINGS.json`` in the current directory.

    Every value is passed through ``os.path.expandvars``, so entries may
    reference environment variables (e.g. ``"$DATA_DIR/train.csv"``).

    Returns:
        dict mapping each setting name to its value with environment
        variables expanded. Values are expected to be strings, since
        ``os.path.expandvars`` requires ``str``/``bytes``.

    Raises:
        FileNotFoundError: if ``SETTINGS.json`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Context manager so the handle is closed right away rather than left to
    # the garbage collector; JSON is UTF-8 by specification, so don't rely on
    # the platform's locale encoding.
    with open("SETTINGS.json", encoding="utf-8") as settings_file:
        paths = json.load(settings_file)
    return {key: os.path.expandvars(value) for key, value in paths.items()}
def parse_dataframe(df):
    """Convert every cell of *df* from a space-separated string to a float array.

    Each cell is parsed with ``np.fromstring(cell, dtype=float, sep=" ")``,
    so the string ``"1 2.5 3"`` becomes ``array([1. , 2.5, 3. ])``.

    Args:
        df: DataFrame whose cells are strings of whitespace-separated numbers.

    Returns:
        A new DataFrame with the same index and columns, where each cell holds
        a 1-D ``float64`` NumPy array.
    """
    def parse_cell(cell):
        # Builtin ``float`` rather than the ``np.float`` alias, which was
        # deprecated in NumPy 1.20 and removed in 1.24.
        return np.fromstring(cell, dtype=float, sep=" ")

    # ``DataFrame.applymap`` was renamed to ``DataFrame.map`` in pandas 2.1
    # (``applymap`` now emits a FutureWarning). Check the class, not the
    # instance, so a column that happens to be named "map" cannot shadow it.
    if hasattr(pd.DataFrame, "map"):
        return df.map(parse_cell)
    return df.applymap(parse_cell)
def read_train_pairs():
    """Load the training pairs CSV and parse every cell into a numeric array.

    The file location is the ``train_pairs_path`` setting from ``get_paths``;
    rows are indexed by ``SampleID`` and cells are converted by
    ``parse_dataframe``.
    """
    location = get_paths()["train_pairs_path"]
    raw_pairs = pd.read_csv(location, index_col="SampleID")
    return parse_dataframe(raw_pairs)
def read_train_target():
    """Load the training target CSV and rename its first two columns.

    Reads the ``train_target_path`` setting, indexed by ``SampleID``. The
    first two remaining columns are renamed to ``Target`` and ``Details``;
    since ``zip`` stops at the shorter sequence, any further columns keep
    their original names.
    """
    settings = get_paths()
    targets = pd.read_csv(settings["train_target_path"], index_col="SampleID")
    renames = {old: new for old, new in zip(targets.columns, ["Target", "Details"])}
    return targets.rename(columns=renames)
def read_train_info():
    """Return the CSV at the ``train_info_path`` setting, indexed by ``SampleID``."""
    settings = get_paths()
    info = pd.read_csv(settings["train_info_path"], index_col="SampleID")
    return info
def read_valid_pairs():
    """Load the validation pairs CSV and parse every cell into a numeric array.

    The file location is the ``valid_pairs_path`` setting from ``get_paths``;
    rows are indexed by ``SampleID`` and cells are converted by
    ``parse_dataframe``.
    """
    settings = get_paths()
    raw_pairs = pd.read_csv(settings["valid_pairs_path"], index_col="SampleID")
    return parse_dataframe(raw_pairs)
def read_valid_info():
    """Return the CSV at the ``valid_info_path`` setting, indexed by ``SampleID``."""
    return pd.read_csv(get_paths()["valid_info_path"], index_col="SampleID")
def read_solution():
    """Return the CSV at the ``solution_path`` setting, indexed by ``SampleID``."""
    location = get_paths()["solution_path"]
    solution = pd.read_csv(location, index_col="SampleID")
    return solution
def save_model(model):
    """Pickle *model* to the file named by the ``model_path`` setting.

    Any existing file at that path is overwritten.

    Args:
        model: Any picklable object.
    """
    out_path = get_paths()["model_path"]
    # Pickle produces bytes, so the file must be opened in binary mode
    # (text mode "w" raises TypeError on Python 3). ``with`` guarantees the
    # data is flushed and the handle closed.
    with open(out_path, "wb") as model_file:
        pickle.dump(model, model_file)
def load_model():
    """Unpickle and return the object stored at the ``model_path`` setting.

    Security: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only load model files written by this project (``save_model``),
    never files obtained from an untrusted source.
    """
    in_path = get_paths()["model_path"]
    # Binary mode is required: pickle data is bytes, and reading it through a
    # text-mode handle fails on Python 3. ``with`` closes the handle.
    with open(in_path, "rb") as model_file:
        return pickle.load(model_file)
def read_submission():
    """Read back the CSV at the ``submission_path`` setting, indexed by ``SampleID``."""
    settings = get_paths()
    return pd.read_csv(settings["submission_path"], index_col="SampleID")
def write_submission(predictions):
    """Write *predictions* as a CSV to the ``submission_path`` setting.

    The output has a ``SampleID,Target`` header followed by one row per
    validation pair. Each row pairs a ``SampleID`` from ``read_valid_pairs()``
    (in file order) with the prediction at the same position.

    Args:
        predictions: Iterable of predictions, in the same order as the
            validation pairs.

    Note:
        Rows are formed with ``zip``, so if *predictions* and the validation
        set differ in length, the surplus items are silently dropped.
    """
    submission_path = get_paths()["submission_path"]
    # Load the validation index before opening the output, so a failure while
    # reading does not leave behind a truncated submission file.
    valid = read_valid_pairs()
    # newline="" is what the csv module requires for files it writes; without
    # it the explicit "\n" terminator becomes "\r\n" on Windows. ``with``
    # guarantees the rows are flushed and the file closed.
    with open(submission_path, "w", newline="") as submission_file:
        writer = csv.writer(submission_file, lineterminator="\n")
        writer.writerow(("SampleID", "Target"))
        writer.writerows(zip(valid.index, predictions))