
Commit

tweaks
mcabbott committed Jul 25, 2023
1 parent 35bc6e1 commit b3d9a50
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions vision/convmixer_cifar10/convmixer.jl
@@ -1,8 +1,12 @@
 using Flux, MLDatasets
 using Flux: onehotbatch, onecold, DataLoader, flatten, OptimiserChain
 using BSON:@save,@load
-Flux._old_to_new(rule::ClipNorm) = Flux.Optimisers.ClipNorm(rule.thresh) # wrong in Flux 0.13.9
+ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
+
+# wrong in Flux 0.13.9
+Flux._old_to_new(rule::ClipNorm) = Flux.Optimisers.ClipNorm(rule.thresh)
 
+# Also, quick test of train(epochs=10, images=128) shows increasing loss, not sure why.
 
 function ConvMixer(in_channels, kernel_size, patch_size, dim, depth, N_classes)
     f = Chain(
@@ -22,19 +26,19 @@ function ConvMixer(in_channels, kernel_size, patch_size, dim, depth, N_classes)
     return f
 end
 
-function get_data(batchsize; dataset = MLDatasets.CIFAR10, idxs = nothing)
-    """
-    idxs=nothing gives the full dataset, otherwise (for testing purposes) only the 1:idxs elements of the train set are given.
-    """
-    ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
+"""
+By default gives the full dataset, keyword images gives (for testing purposes)
+only the 1:images elements of the train set.
+"""
+function get_data(batchsize; dataset = MLDatasets.CIFAR10, images = :)
 
     # Loading Dataset
-    if idxs===nothing
+    if images === (:)
         xtrain, ytrain = dataset(:train)[:]
         xtest, ytest = dataset(:test)[:]
     else
-        xtrain, ytrain = dataset(:train)[1:idxs]
-        xtest, ytest = dataset(:test)[1:Int(idxs/10)]
+        xtrain, ytrain = dataset(:train)[1:images]
+        xtest, ytest = dataset(:test)[1:(images÷10)]
     end
 
     # Reshape Data to comply to Julia's (width, height, channels, batch_size) convention in case there are only 1 channel (eg MNIST)
@@ -74,10 +78,10 @@ function create_loss_function(dataloader, device)
 end
 
 
-function train(n_epochs=100)
+function train(; epochs=100, images=:)
 
     #params: warning, the training can be long with these params
-    train_loader, test_loader = get_data(128)
+    train_loader, test_loader = get_data(128; images)
     η = 3f-4
     in_channel = 3
     patch_size = 2
@@ -88,8 +92,8 @@ function train(n_epochs=100)
     use_cuda = true
 
     #logging the losses
-    train_save = zeros(n_epochs, 2)
-    test_save = zeros(n_epochs, 2)
+    train_save = zeros(epochs, 2)
+    test_save = zeros(epochs, 2)
 
     if use_cuda
         device = gpu
@@ -111,11 +115,11 @@
     )
     state = Flux.setup(opt, model)
 
-    for epoch in 1:n_epochs
+    for epoch in 1:epochs
         for (x,y) in train_loader
             x,y = x|>device, y|>device
             grads = gradient(m->Flux.logitcrossentropy(m(x), y, agg=sum), model)
-            Flux.Optimise.update!(state, model, grads[1])
+            Flux.update!(state, model, grads[1])
         end
 
         #logging
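The commit's own comment in the first hunk suggests `train(epochs=10, images=128)` as a quick test of the new keyword interface. A hedged usage sketch, assuming the whole updated script has been loaded; the `include` path and call values are illustrative:

```julia
include("convmixer.jl")        # load the script changed above

# quick smoke test on a small slice of CIFAR10, as in the commit's comment
train(epochs=10, images=128)

# full run with the defaults: epochs=100, images=: (the whole train set)
train()

# the loaders can also be built directly; images=128 keeps elements 1:128 of the
# train set and 1:12 (128 ÷ 10) of the test set
train_loader, test_loader = get_data(128; images=128)
```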

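The last hunk also keeps Flux's explicit-gradient style: `Flux.setup` builds the optimiser state and `Flux.update!` replaces `Flux.Optimise.update!`. A minimal standalone sketch of that one-step pattern, using a small placeholder model and random data rather than the script's ConvMixer and CIFAR10 loaders:

```julia
using Flux

model = Chain(Dense(4 => 8, relu), Dense(8 => 2))   # placeholder model
state = Flux.setup(Adam(3f-4), model)               # optimiser state tree, one leaf per parameter array

x = randn(Float32, 4, 16)                           # 16 random 4-feature samples
y = Flux.onehotbatch(rand(1:2, 16), 1:2)            # random labels, one-hot encoded

# one training step, matching the loop in the last hunk
grads = gradient(m -> Flux.logitcrossentropy(m(x), y, agg=sum), model)
Flux.update!(state, model, grads[1])                # new-style call replacing Flux.Optimise.update!
```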
0 comments on commit b3d9a50
