Skip to content

Commit

Permalink
Merge pull request #1 from tshort/tom-reorg-expand
Browse files Browse the repository at this point in the history
Tom reorg expand
  • Loading branch information
doobwa committed Jul 26, 2012
2 parents 4e27c6f + 84bca7b commit ebe0f57
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 37 deletions.
29 changes: 28 additions & 1 deletion src/dataframe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ done(df::AbstractDataFrame, i) = i > ncol(df)
# Iteration step: yield column `i` together with the state for the following step.
function next(df::AbstractDataFrame, i)
    return (df[i], i + 1)
end
## numel(df::AbstractDataFrame) = ncol(df)
# A data frame that has no columns is considered empty.
function isempty(df::AbstractDataFrame)
    return ncol(df) == 0
end
# Column groups: every accessor simply forwards to the frame's column index
# (`d.colindex`), which is where the group bookkeeping is stored.
function set_group(d::DataFrame, newgroup, names)
    return set_group(d.colindex, newgroup, names)
end
function set_groups(d::DataFrame, gr::Dict{ByteString,Vector{ByteString}})
    return set_groups(d.colindex, gr)
end
function get_groups(d::DataFrame)
    return get_groups(d.colindex)
end

function insert(df::DataFrame, index::Integer, item, name)
@assert 0 < index <= ncol(df) + 1
Expand All @@ -86,6 +90,14 @@ function insert(df::DataFrame, index::Integer, item, name)
df[[1:index-1, end, index:end-1]]
end

# Add the columns of `df2` to `df` in place, one `df[name] = df2[name]` assignment
# per column name of `df2`, and return the (mutated) `df`.
# The two frames must have the same number of rows unless `df` has no rows yet.
function insert(df::DataFrame, df2::DataFrame)
    @assert nrow(df) == nrow(df2) || nrow(df) == 0
    incoming = colnames(df2)
    for k in 1:length(incoming)
        nm = incoming[k]
        df[nm] = df2[nm]
    end
    return df
end

# Fallback constructor: wrap each positional argument in a DataVec and pass the
# resulting array of columns on to the DataFrame constructor, hoping for the best.
function DataFrame(vals...)
    columns = [DataVec(v) for v in vals]
    return DataFrame(columns)
end
# if we have a matrix, create a tuple of columns and pass that in
Expand Down Expand Up @@ -162,6 +174,12 @@ maxShowLength(dv::AbstractDataVec) = max([length(_string(x)) for x = dv])
function show(io, df::AbstractDataFrame)
## TODO use alignment() like print_matrix in show.jl.
println(io, "$(typeof(df)) $(size(df))")
gr = get_groups(df)
if length(gr) > 0
#print(io, "Column groups: ")
pretty_show(io, gr)
println(io)
end
N = nrow(df)
Nmx = 20 # maximum head and tail lengths
if N <= 2Nmx
Expand Down Expand Up @@ -203,6 +221,11 @@ end

function dump(io::IOStream, x::AbstractDataFrame, n::Int, indent)
println(io, typeof(x), " $(nrow(x)) observations of $(ncol(x)) variables")
gr = get_groups(x)
if length(gr) > 0
pretty_show(io, gr)
println(io)
end
if n > 0
for col in names(x)[1:min(10,end)]
print(io, indent, " ", col, ": ")
Expand Down Expand Up @@ -342,7 +365,11 @@ function csvDataFrame(filename, o::Options)
@check_used o

# combine the columns into a DataFrame and return
DataFrame(cols, columnNames)
if columnNames == []
DataFrame(cols)
else
DataFrame(cols, columnNames)
end
end
# Convenience method: read `filename` using a default-constructed `Options()`.
function csvDataFrame(filename)
    return csvDataFrame(filename, Options())
end

Expand Down
76 changes: 44 additions & 32 deletions src/formula.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ end

# Obtain Array of Symbols used in an Expr.
#
# `ex.args[1]` (the operator / function being applied) is skipped; the symbols of
# every remaining argument are collected recursively and concatenated into one
# array. This handles any number of operands -- `I(x)`, `a + b`, `a * b * c` --
# instead of special-casing one and two operands, which silently dropped the third
# and later operands of an n-ary call and failed on a call with no arguments.
function unique_symbols(ex::Expr)
    [[unique_symbols(a) for a in ex.args[2:end]]...]
end
# Array-of-expressions case: only the first expression, `ex[1]`, is inspected; its
# result is wrapped in an outer array literal.
function unique_symbols(ex::Array{Expr,1})
    return [unique_symbols(ex[1])]
end
# Base case: a lone symbol yields a one-element array holding that symbol.
function unique_symbols(ex::Symbol)
    return [ex]
end
Expand Down Expand Up @@ -172,42 +168,58 @@ function expand_helper(ex::Symbol, df::DataFrame)
return r
end

function expand_helper(ex::Expr, df::DataFrame)
if length(ex.args) == 2 # e.g. log(x2)
a = with(df, ex)
r = DataFrame()
r[string(ex)] = a
else
r = expand(ex, df)
#
# The main expression to DataFrame expansion function.
# Returns a DataFrame.
#
# The head of the call, `ex.args[1]`, is evaluated to obtain the function `f`.
# If `f` has a method for `(FormulaExpander, Vector{Any}, DataFrame)` it is treated
# as a formula operator and receives the still-unexpanded operands `ex.args[2:end]`,
# so it can decide how to combine their expansions. Otherwise the value of
# `with(df, ex)` is expanded under the name `string(ex)`.
#
function expand(ex::Expr, df::DataFrame)
    f = eval(ex.args[1])
    if method_exists(f, (FormulaExpander, Vector{Any}, DataFrame))
        # These are specialized expander functions (+, *, &, etc.)
        f(FormulaExpander(), ex.args[2:end], df)
    else
        # Everything else is computed with `with` and then expanded according to
        # the type of the resulting value:
        expand(with(df, ex), string(ex), df)
    end
end
return r
end
# A bare symbol: expand the value of `with(df, s)` under the symbol's own name.
expand(s::Symbol, df::DataFrame) = expand(with(df, s), string(s), df)
# Default expansion, used when no more specific method applies to `x`: the value
# becomes the only column of a new DataFrame, labelled `name`. `df` is not used.
function expand(x, name::ByteString, df::DataFrame)
    columns = {x}
    labels = [name]
    return DataFrame(columns, labels)
end

# Expand an Expression (with +, &, or *) using the provided DataFrame
# + includes both columns
# & includes the elementwise product of every pair of columns
# * both of the above
function expand(ex, df::DataFrame)
# Recurse on left and right children of provided Expression
a = expand_helper(ex.args[2], df)
b = expand_helper(ex.args[3], df)
#
# Methods for Formula expansion
#
# Empty marker type. An instance is passed as the first argument to the formula
# versions of `+`, `&` and `*`, so dispatch can tell them apart from other methods
# of those operators; `expand` checks for such a method via `method_exists`.
type FormulaExpander; end # This is an indicator type.

# Combine according to formula
if ex.args[1] == :+
return cbind(a,b)
elseif ex.args[1] == :&
return interaction_design_matrix(a,b)
elseif ex.args[1] == :*
return cbind(a, b, interaction_design_matrix(a,b))
else
error("Unknown operation in formula")
end
# Formula `+`: expand each operand against `df` and insert the resulting columns,
# in operand order, into one freshly created DataFrame.
function +(::FormulaExpander, args::Vector{Any}, df::DataFrame)
    combined = DataFrame()
    for k in 1:length(args)
        insert(combined, expand(args[k], df))
    end
    return combined
end
# Formula `&`: the interaction design matrix of the expansions of the first two
# operands. Only `args[1]` and `args[2]` are consulted.
function &(::FormulaExpander, args::Vector{Any}, df::DataFrame)
    lhs = expand(args[1], df)
    rhs = expand(args[2], df)
    return interaction_design_matrix(lhs, rhs)
end
# Formula `*`: the `+` expansion of all operands, followed by the interaction of the
# first two operands inserted into the same DataFrame.
function *(::FormulaExpander, args::Vector{Any}, df::DataFrame)
    result = +(FormulaExpander(), args, df)
    # TODO still not right - need all combinations here for a*b*c:
    lhs = expand(args[1], df)
    rhs = expand(args[2], df)
    insert(result, interaction_design_matrix(lhs, rhs))
    return result
end

#
# Methods for expansion of specific data types
#

# Expand a PooledDataVector into a matrix of indicators for each dummy variable
# TODO: account for NAs?
function expand(poolcol::PooledDataVec, colname::String)
function expand(poolcol::PooledDataVec, colname::ByteString, df::DataFrame)
newcol = {DataVec([convert(Float64,x)::Float64 for x in (poolcol.refs .== i)]) for i in 2:length(poolcol.pool)}
newcolname = [strcat(colname, ":", x) for x in poolcol.pool[2:length(poolcol.pool)]]
DataFrame(newcol, convert(Vector{ByteString}, newcolname))
Expand Down
53 changes: 49 additions & 4 deletions src/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
abstract AbstractIndex

type Index <: AbstractIndex # an OrderedDict would be nice here...
lookup::Dict{ByteString,Int} # name => names array position
lookup::Dict{ByteString,Indices} # name => names array position
names::Vector{ByteString}
end
Index{T<:ByteString}(x::Vector{T}) = Index(Dict{ByteString, Int}(tuple(x...), tuple([1:length(x)]...)),
Index{T<:ByteString}(x::Vector{T}) = Index(Dict{ByteString, Indices}(tuple(x...), tuple([1:length(x)]...)),
convert(Vector{ByteString}, x))
Index() = Index(Dict{ByteString,Int}(), ByteString[])
Index() = Index(Dict{ByteString,Indices}(), ByteString[])
# Number of entries in the index, taken from its `names` vector.
function length(x::Index)
    return length(x.names)
end
# The names, returned as a copy so callers cannot mutate the index's own vector.
function names(x::Index)
    return copy(x.names)
end
# New Index built from copies of both the lookup table and the names vector.
function copy(x::Index)
    return Index(copy(x.lookup), copy(x.names))
end
Expand All @@ -21,6 +21,10 @@ function names!(x::Index, nm::Vector)
if length(nm) != length(x)
error("lengths don't match.")
end
for i in 1:length(nm)
del(x.lookup, x.names[i])
x.lookup[nm[i]] = i
end
x.names = nm
end

Expand Down Expand Up @@ -51,8 +55,14 @@ function del(x::Index, idx::Integer)
for i in idx+1:length(x.names)
x.lookup[x.names[i]] = i - 1
end
gr = get_groups(x)
del(x.lookup, x.names[idx])
del(x.names, idx)
# fix groups:
for (k,v) in gr
newv = [[has(x, vv) ? vv : ASCIIString[] for vv in v]...]
set_group(x, k, newv)
end
end
function del(x::Index, nm)
if !has(x.lookup, nm)
Expand All @@ -62,7 +72,7 @@ function del(x::Index, nm)
del(x, idx)
end

ref{T<:ByteString}(x::Index, idx::Vector{T}) = convert(Vector{Int}, [x.lookup[i] for i in idx])
ref{T<:ByteString}(x::Index, idx::Vector{T}) = [[x.lookup[i] for i in idx]...]
# Single-name lookup: return whatever the lookup table stores for `idx`.
function ref{T<:ByteString}(x::Index, idx::T)
    return x.lookup[idx]
end

# fall-throughs, when something other than the index type is passed
Expand All @@ -80,3 +90,38 @@ end
# A SimpleIndex constructed without arguments has length 0.
function SimpleIndex()
    return SimpleIndex(0)
end
# The length is stored directly in the `length` field.
function length(x::SimpleIndex)
    return x.length
end
# A SimpleIndex carries no names.
function names(x::SimpleIndex)
    return nothing
end

# Chris's idea of namespaces adapted by Harlan for column groups
# Define (or redefine) the group `newgroup` as the lookup entries of `names`,
# flattened into one array. Nothing is changed when `newgroup` is already present
# in the index with a non-Array lookup entry (presumably an ordinary column).
function set_group(idx::Index, newgroup, names)
    taken = has(idx, newgroup) && !isa(idx.lookup[newgroup], Array)
    if !taken
        positions = [idx.lookup[nm] for nm in names]
        idx.lookup[newgroup] = [positions...]
    end
end
# Register every group in `gr` (group name => member names). A group whose name is
# already present in the index is skipped; otherwise the members' lookup entries
# are flattened into one array and stored under the group name.
function set_groups(idx::Index, gr::Dict{ByteString,Vector{ByteString}})
    for (groupname, members) in gr
        if has(idx, groupname)
            continue
        end
        positions = [idx.lookup[m] for m in members]
        idx.lookup[groupname] = [positions...]
    end
end
# Collect the groups stored in the index: every lookup entry whose value is an
# Array is reported as group name => the names found at those positions.
function get_groups(idx::Index)
    groups = Dict{ByteString,Vector{ByteString}}()
    for (key, positions) in idx.lookup
        if !isa(positions, Array)
            continue
        end
        groups[key] = idx.names[positions]
    end
    return groups
end

# Special pretty-printer for groups, which are just Dicts.
# Writes each group as "name: member, member" and separates groups with "; ";
# there is no trailing separator and no newline.
function pretty_show(io, gr::Dict{ByteString,Vector{ByteString}})
    entries = [string("$(k): ", join(gr[k], ", ")) for k in keys(gr)]
    print(io, join(entries, "; "))
end
1 change: 1 addition & 0 deletions test/formula.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ d["y"] = [1:4]
# x1 is stored as a PooledDataVec; the remaining predictors are plain ranges.
d["x1"] = PooledDataVec([5:8])
d["x2"] = [9:12]
d["x3"] = [11:14]
# x4 is present in the data but is not referenced by the formula below.
d["x4"] = [12:15]
# The formula combines `*` with a parenthesised `+` and a function call (log).
f = Formula(:(y ~ x1 * (log(x2) + x3)))
# Build the model frame from the formula and data, then the model matrix from it.
mf = model_frame(f, d)
mm = model_matrix(mf)
Expand Down

0 comments on commit ebe0f57

Please sign in to comment.