-
Notifications
You must be signed in to change notification settings - Fork 100
/
util.R
100 lines (92 loc) · 3 KB
/
util.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
require(plyr)
# TODO: point these at the correct data, submission and R-source directories
# if the project layout differs. Each path keeps its trailing slash because
# callers build file names with paste0().
paths <- list(
  data = '../data/',
  submit = '../submissions/',
  r = '../R/'
)
sample.submission <- function(path = paste0(paths$data, 'sampleSubmission.csv')) {
  # Loads the sample submission, which is used as the template (Id column)
  # when writing predictions.
  #
  # args:
  #   path - CSV file to read; defaults to sampleSubmission.csv inside
  #          paths$data (evaluated lazily, only when no path is given)
  #
  # returns:
  #   the sample submission as a data frame. It is returned visibly; the
  #   previous version ended in an assignment, so its result was invisible.
  read.csv(path)
}
raw.train <- function(path = paste0(paths$data, 'train.csv')) {
  # Loads the training data with explicit column classes.
  #
  # args:
  #   path - CSV file to read; defaults to train.csv inside paths$data
  #
  # returns:
  #   a data frame whose five columns are read as factor, factor, Date,
  #   numeric and logical, in that order. It is returned visibly.
  cls <- c('factor', 'factor', 'Date', 'numeric', 'logical')
  read.csv(path, colClasses = cls)
}
raw.test <- function(path = paste0(paths$data, 'test.csv')) {
  # Loads the test data with explicit column classes.
  #
  # args:
  #   path - CSV file to read; defaults to test.csv inside paths$data
  #
  # returns:
  #   a data frame whose four columns are read as factor, factor, Date and
  #   logical, in that order. It is returned visibly.
  cls <- c('factor', 'factor', 'Date', 'logical')
  read.csv(path, colClasses = cls)
}
reload.submission <- function(submit.num){
  # Reads back the file submission<submit.num>.csv from paths$submit.
  #
  # args:
  #   submit.num - the number of the submission to load
  #
  # returns:
  #   the saved submission as a data frame (Id and Weekly_Sales columns)
  file.name <- paste0('submission', submit.num, '.csv')
  read.csv(paste0(paths$submit, file.name))
}
make.average <- function(submissions, wts=NULL){
  # Averages previously saved submissions.
  #
  # args:
  #   submissions - a non-empty vector of submission numbers
  #   wts - optional vector of weights, one per submission (default: equal)
  #
  # returns:
  #   the sample-submission data frame with Weekly_Sales replaced by the
  #   weighted average of the submissions' Weekly_Sales columns. Rows are
  #   combined by position, so every submission must have as many rows as
  #   the sample submission.
  if(is.null(wts)){
    wts <- rep(1, length(submissions))
  }
  if(length(submissions) == 0){
    stop('make.average: no submissions to average', call. = FALSE)
  }
  stopifnot(length(wts) == length(submissions), sum(wts) != 0)
  pred <- sample.submission()
  # Accumulate from zero. The template's own Weekly_Sales values must not
  # leak into the average.
  total <- numeric(nrow(pred))
  for(k in seq_along(submissions)){
    sub.k <- reload.submission(submissions[k])
    # Guard against silent recycling when a file has the wrong row count.
    stopifnot(nrow(sub.k) == nrow(pred))
    total <- total + wts[k] * sub.k$Weekly_Sales
  }
  pred$Weekly_Sales <- total / sum(wts)
  pred
}
write.submission <- function(pred){
  # Writes a valid submission to paths$submit.
  #
  # The new file is named submission<N>.csv, where N is one more than the
  # highest number among existing files named exactly submission<N>.csv,
  # submission<N>.csv.zip or submission<N>.csv.gz. N is 1 if none exist.
  #
  # args:
  #   pred - a data frame with predictions in the Weekly_Sales field, in
  #          the same row order as the sample submission
  #
  # returns:
  #   the submission number used
  ss <- sample.submission()
  stopifnot(length(pred$Weekly_Sales) == nrow(ss))
  # The pattern is anchored and its dots are escaped. Without that, stray
  # files such as 'old_submission3.csv.bak' matched and turned the number
  # into NA.
  pattern <- '^submission([0-9]+)\\.csv(\\.zip|\\.gz)?$'
  subs <- grep(pattern, dir(paths$submit), value = TRUE)
  nums <- as.numeric(sub(pattern, '\\1', subs))
  if(length(nums) == 0){
    submission.number <- 1
  }else{
    submission.number <- max(nums) + 1
  }
  ss$Weekly_Sales <- pred$Weekly_Sales
  submit.path <- paste0(paths$submit, 'submission', submission.number, '.csv')
  print(paste('Writing to:', submit.path))
  write.csv(ss, file = submit.path, quote = FALSE, row.names = FALSE)
  submission.number
}
wmae <- function(pred, test){
  # Weighted mean absolute error, the Kaggle/Walmart evaluation metric.
  # Each row is weighted 1 + 4 * IsHoliday, so TRUE rows count 5 and FALSE
  # rows count 1.
  #
  # args:
  #   pred - a data frame with predictions in the Weekly_Sales field
  #   test - a data frame with an IsHoliday field and the ground truth in
  #          the Weekly_Sales field
  # returns:
  #   the weighted mean absolute error
  weights <- 1 + 4 * test$IsHoliday
  abs.err <- abs(pred$Weekly_Sales - test$Weekly_Sales)
  sum(weights * abs.err) / sum(weights)
}