Skip to content

Commit

Permalink
Adding the two stopping criteria for add_knots function as described …
Browse files Browse the repository at this point in the history
…in the paper.
  • Loading branch information
ranibasna committed Oct 8, 2023
1 parent 0b3f7ea commit 1091337
Showing 1 changed file with 89 additions and 116 deletions.
205 changes: 89 additions & 116 deletions R/add_knots.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,170 +74,143 @@


#############
add_knots = function(f, f_v=NULL, knots, L, M=5, auto_stop=FALSE, threshold=NULL, stop_method="absolute") {
  # Greedily adds up to L knots to an initial knot set, one at a time, and
  # tracks the average approximation error after every addition.
  #
  # Arguments:
  #   f           matrix whose columns are grid points (a plain vector is
  #               treated as a single-row matrix).
  #   f_v         optional validation data, indexed on the same grid as f.
  #   knots       initial knots given as column indices; interval k covers
  #               columns (knots[k]+1):knots[k+1].
  #   L           maximal number of knots to add.
  #   M           forwarded to opt_split()/add_split(); constrains the number
  #               of points per between-knots interval.
  #   auto_stop   if TRUE and f_v is supplied, the validation error is tracked
  #               and the loop stops once it changes by less than `threshold`.
  #   threshold   numeric scalar, required when the stopping rule is active.
  #   stop_method "absolute": stop when |change| < threshold;
  #               "relative": stop when |change| < threshold * |current error|.
  #
  # Returns a list with Fknots (final knots), FAMSE (interval-wise average MSE)
  # and APPRERR (approximation error for the input knots followed by one entry
  # per added knot; entries for iterations that never ran stay 0). When f_v is
  # supplied the list also carries TMSE_v, the sequence of validation errors.

  # A single function given as a vector becomes a one-row matrix.
  if (is.vector(f)) {
    f = matrix(data = f, nrow = 1, ncol = length(f))
  }
  if (!is.null(f_v) && is.vector(f_v)) {
    f_v = matrix(data = f_v, nrow = 1, ncol = length(f_v))
  }

  nx = dim(f)[2]     # number of grid points
  K = length(knots)  # number of initial knots
  if (K < 2) {
    stop("'knots' must contain at least two knots.", call. = FALSE)
  }

  # The stopping rule is only active with validation data and auto_stop = TRUE;
  # validate its settings once, before any work is done.
  use_stop = !is.null(f_v) && auto_stop
  if (use_stop) {
    stop_method = match.arg(stop_method, c("absolute", "relative"))
    if (!is.numeric(threshold) || length(threshold) != 1 || is.na(threshold)) {
      stop("'threshold' must be a single number when 'auto_stop' is TRUE.", call. = FALSE)
    }
  }

  AMSE = vector('numeric', K - 1)     # average MSE per input interval
  APPRERR = vector('numeric', L + 1)  # approximation error per step

  LE = knots[-K]  # open left ends of the intervals
  RE = knots[-1]  # closed right ends of the intervals

  splits = vector('numeric', K - 1)  # interval-wise optimal split points
  AMSE1 = splits                     # average MSE left of each split
  AMSE2 = splits                     # average MSE right of each split

  # Initial pass: average MSE and optimal split for every input interval.
  for (k in seq_len(K - 1)) {
    ff = f[, (knots[k] + 1):(knots[k + 1]), drop = FALSE]
    AMSE[k] = amse(ff)
    newsp = opt_split(ff, AMSE[k], M = M)
    splits[k] = knots[k] + newsp[[1]]
    AMSE1[k] = newsp[[2]]
    AMSE2[k] = newsp[[3]]
  }

  # Validation bookkeeping (only meaningful when f_v is supplied).
  if (!is.null(f_v)) {
    nx_v = length(na.omit(f_v[1, ]))
    len_v = nx_v
    TMSE_v = c()
    AMSE_v = c()
  }

  # Approximation error for the input knots.
  APPRERR[1] = sum((RE - LE) * AMSE) / nx

  # Running state, updated each time a knot is added.
  FLE = LE
  FRE = RE
  FAMS = AMSE
  Fspl = splits
  FAMS1 = AMSE1
  FAMS2 = AMSE2

  prev_TMSE_v = NULL  # validation error recorded at the previous step

  for (i in seq_len(L)) {
    # No admissible split is left under the constraint M.
    if (all(is.na(Fspl))) {
      warning(paste0('There are only ', K + i - 1, ' knots. Reduce L or M.'))
      break
    }

    opt2 = add_split(f, FLE, FRE, FAMS, FAMS1, FAMS2, Fspl, M = M)
    l = opt2$locsp  # index of the interval that receives the new knot
    NL = opt2$NLE
    NR = opt2$NRE
    AMS = opt2$NAMSE
    AMS1 = opt2$NAMSE1
    AMS2 = opt2$NAMSE2
    spl = opt2$nsplits

    # New average sum of the squared norms of errors.
    APPRERR[1 + i] = sum((NR - NL) * AMS) / nx

    # Insert the new knot into the full set of interval endpoints.
    FLE = append(FLE, NL[l + 1], after = l)
    FRE = append(FRE, NR[l], after = l - 1)

    # Interval l is replaced by its two halves in every per-interval vector.
    FAMS[l] = AMS[l]
    FAMS = append(FAMS, AMS[l + 1], after = l)
    Fspl[l] = spl[l]
    Fspl = append(Fspl, spl[l + 1], after = l)
    FAMS1[l] = AMS1[l]
    FAMS1 = append(FAMS1, AMS1[l + 1], after = l)
    FAMS2[l] = AMS2[l]
    FAMS2 = append(FAMS2, AMS2[l + 1], after = l)

    if (use_stop) {
      # Validation data over the two intervals created by the new knot.
      c_f_v = f_v[, NL[l]:NR[l + 1]]
      # NOTE(review): split() is called as split(values, position) and its
      # result is used as the pair of validation AMSEs -- presumably a
      # project-level helper rather than base::split; confirm.
      s_v = split(c_f_v, NR[l] - NL[l])

      # Steps whose validation AMSE is unavailable are skipped entirely.
      if (!(NA %in% s_v)) {
        # Number of non-missing validation points in each new interval.
        l_v = c(length(na.omit(f_v[1, (NL[l] + 1):NR[l]])),
                length(na.omit(f_v[1, (NL[l + 1] + 1):NR[l + 1]])))
        AMSE_v[l] = s_v[1]
        AMSE_v = append(AMSE_v, s_v[2], after = l)
        len_v[l] = l_v[1]
        len_v = append(len_v, l_v[2], after = l)

        cur_TMSE_v = sum(len_v * AMSE_v) / nx_v
        TMSE_v = append(TMSE_v, cur_TMSE_v)

        # Compare against the last recorded value, not TMSE_v[i]: after a
        # skipped step TMSE_v has fewer than i entries.
        if (!is.null(prev_TMSE_v)) {
          abs_diff = abs(cur_TMSE_v - prev_TMSE_v)
          limit = if (stop_method == "absolute") threshold else threshold * abs(cur_TMSE_v)
          if (isTRUE(abs_diff < limit)) {
            break
          }
        }
        prev_TMSE_v = cur_TMSE_v
      }
    }
  }

  Fknots = c(FLE, FRE[length(FRE)])
  if (is.null(f_v)) {
    result = list(Fknots=Fknots, FAMSE=FAMS, APPRERR=APPRERR)
  } else {
    result = list(Fknots=Fknots, FAMSE=FAMS, APPRERR=APPRERR, TMSE_v=TMSE_v)
  }

  return(result)
}
########

0 comments on commit 1091337

Please sign in to comment.