# vim-autograd Examples

Here are some examples of using vim-autograd. With vim-autograd installed, each example can be run in Vim with `:source`, e.g. `:source examples/basic.vim`.

## Simplest Differentiation

This example performs define-by-run style automatic differentiation through a PyTorch-like interface.

First, apply differentiable operations, as functions or methods, to Tensor objects created by the `autograd#tensor` function. Then, when the `backward` method is called on the output Tensor, the gradient is computed by traversing the computational graph built during forward propagation in reverse. Finally, the gradient propagated from the output is accumulated in the `.grad` attribute of the input Tensor.
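A minimal sketch of this flow, using only the `autograd#tensor`, `autograd#mul`, `backward()`, and `.grad` API shown in this repository:

```vim
" d(x * x)/dx = 2x, so x.grad should hold [6.0] at x = 3
let s:x = autograd#tensor(3.0)
let s:y = autograd#mul(s:x, s:x)
call s:y.backward()
echo s:x.grad.data
```

The full example below does the same thing for a slightly larger expression and also dumps the computational graph.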

`examples/basic.vim`

```vim
function! s:f(x) abort
  " y = x^5 - 2x^3
  let y = autograd#sub(a:x.p(5), a:x.p(3).m(2))
  return y
endfunction

function! s:main() abort
  let x = autograd#tensor(2.0)

  let y = s:f(x)
  call y.backward()

  echo x.grad.data

  let x.name = 'x'
  let y.name = 'y'
  call autograd#dump_graph(y, '.autograd/example1.png')
endfunction

call s:main()
```

Output

```
[56.0]
```

Generated computational graph

To output a graph as an image file, graphviz must be installed. On Ubuntu, this can be done with the following command.

```sh
$ sudo apt install graphviz
```

If Graphviz is not installed, only the DOT language source code is generated.
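If you want to warn about this up front, you can check for the `dot` command yourself with Vim's built-in `executable()`; a small sketch, reusing `y` from the example above:

```vim
" autograd#dump_graph() falls back to emitting only the DOT source
" when Graphviz is missing; this just reports that case explicitly
if !executable('dot')
  echomsg 'graphviz (dot) not found: only the DOT source will be generated'
endif
call autograd#dump_graph(y, '.autograd/example1.png')
```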

## Higher-order Differentiation

Because vim-autograd supports double backpropagation, you can compute higher-order derivatives by differentiating the first-order derivative again.

Both `.backward()` and `autograd#grad()` can be used for differentiation. The former accumulates gradients in the input variable, so the gradient must be cleared between differentiations; the latter does not touch the input variable's gradient at all. The following two snippets do almost the same thing.

`.backward()`

```vim
let y = f(x)
call y.backward(1)
let gx1 = x.grad

call x.cleargrad()
call gx1.backward()
let gx2 = x.grad
```

`autograd#grad()`

```vim
let y = f(x)
let gx1 = autograd#grad(y, x, 1)
let gx2 = autograd#grad(gx1, x, 1)
```

To enable double-backprop, the first argument of `backward()` or the third argument of `autograd#grad()` must be set to 1 (true).

The following is an example of computing the third-order derivative. Here the Tensor methods `a()`, `m()`, `n()`, and `p()` are shorthands for add, multiply, negate, and power, respectively.

`examples/higher-order.vim`

```vim
function! s:f(x) abort
  " y = x^5 - 2x^3 + 4x^2 + 6x + 5
  let t1 = a:x.p(5)
  let t2 = a:x.p(3).m(2).n()
  let t3 = a:x.p(2).m(4)
  let t4 = a:x.m(6)
  let t5 = 5
  let y = t1.a(t2).a(t3).a(t4).a(t5)
  return y
endfunction

function! s:main() abort
  let x = autograd#tensor(2.0)
  let y = s:f(x)
  echo 'y  :' y.data

  " gx1 = 5x^4 - 6x^2 + 8x + 6
  let gx1 = autograd#grad(y, x, 1)
  echo 'gx1:' gx1.data

  " gx2 = 20x^3 - 12x + 8
  let gx2 = autograd#grad(gx1, x, 1)
  echo 'gx2:' gx2.data

  " gx3 = 60x^2 - 12
  call gx2.backward(1)
  echo 'gx3:' x.grad.data
endfunction

call s:main()
```

Output

```
y  : [49.0]
gx1: [78.0]
gx2: [144.0]
gx3: [228.0]
```
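As a quick sanity check that does not rely on autograd at all, the first derivative can also be approximated with a central difference in plain Vim script. A sketch (`s:poly` and `s:numdiff` are helper names introduced here):

```vim
function! s:poly(x) abort
  " y = x^5 - 2x^3 + 4x^2 + 6x + 5, evaluated on plain floats
  return pow(a:x, 5) - 2 * pow(a:x, 3) + 4 * pow(a:x, 2) + 6 * a:x + 5
endfunction

function! s:numdiff(F, x, h) abort
  " central difference: (f(x + h) - f(x - h)) / (2h)
  return (a:F(a:x + a:h) - a:F(a:x - a:h)) / (2 * a:h)
endfunction

" should print a value very close to gx1 = 78.0 at x = 2
echo s:numdiff(function('s:poly'), 2.0, 1.0e-4)
```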

## Classification using Deep Learning

Since vim-autograd can compute gradients, it can also be used for gradient-descent training of neural networks.
Here we use the Wine dataset, a public toy dataset provided by UCI, to classify three types of wine from a 13-dimensional feature vector.

### Preprocess Dataset

First, we standardize the dataset and split it into a training set and a test set.
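The script below reads the raw CSV from `.autograd/wine.data`, so download it there first. For example (the exact download URL is an assumption; check the dataset page linked in the code if it has moved):

```sh
$ mkdir -p .autograd
$ curl -o .autograd/wine.data https://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data
```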

`examples/wine-classify.vim`

```vim
function! s:get_wine_dataset() abort
  " This refers to the following public toy dataset.
  " https://archive.ics.uci.edu/ml/datasets/Wine
  let dataset = map(readfile('.autograd/wine.data'),
    \ "map(split(v:val, ','), 'str2float(v:val)')")

  let N = len(dataset)

  " average
  let means = repeat([0.0], 14)
  for data in dataset
    for l:i in range(1, 13)
      let means[l:i] += data[l:i]
    endfor
  endfor
  call map(means, 'v:val / N')

  " standard deviation
  let stds = repeat([0.0], 14)
  for data in dataset
    for l:i in range(1, 13)
      let stds[l:i] += pow(data[l:i] - means[l:i], 2)
    endfor
  endfor
  call map(stds, 'sqrt(v:val / N)')

  " standardization
  for data in dataset
    for l:i in range(1, 13)
      let data[l:i] = (data[l:i] - means[l:i]) / stds[l:i]
    endfor
  endfor

  " split the dataset into train and test.
  let train_x = []
  let train_t = []
  let test_x = []
  let test_t = []
  let test_num_per_class = 10
  for l:i in range(3)
    let class_split = autograd#shuffle(
      \ filter(deepcopy(dataset), 'v:val[0] == l:i + 1'))

    let train_split = class_split[:-test_num_per_class - 1]
    let test_split = class_split[-test_num_per_class:]

    let train_x += mapnew(train_split, 'v:val[1:]')
    let train_t += mapnew(train_split, "map(v:val[:0], 'v:val - 1')")
    let test_x += mapnew(test_split, 'v:val[1:]')
    let test_t += mapnew(test_split, "map(v:val[:0], 'v:val - 1')")
  endfor
  return {
    \ 'train': [train_x, train_t],
    \ 'test': [test_x, test_t],
    \ 'insize': len(train_x[0]),
    \ 'nclass': 3,
    \ 'mean': means[1:],
    \ 'std': stds[1:]
    \ }
endfunction
```
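A quick way to inspect the returned dictionary (a sketch; the expected counts assume the standard Wine dataset with 178 samples and 10 held-out test samples per class):

```vim
let s:data = s:get_wine_dataset()
echo len(s:data.train[0]) len(s:data.test[0]) s:data.insize s:data.nclass
" expected: 148 30 13 3
```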

### Build Neural Network

A multi-layer network is then constructed from fully connected layers. Note that only the differentiable functions provided by vim-autograd can be used in the forward pass.

`examples/wine-classify.vim`

```vim
function! s:linear(x, W, b={}) abort
  let t = autograd#matmul(a:x, a:W)
  return empty(a:b) ? t : autograd#add(t, a:b)
endfunction

function! s:relu(x) abort
  return autograd#maximum(a:x, 0.0)
endfunction

function! s:softmax(x) abort
  let y = autograd#exp(a:x.s(autograd#max(a:x)))
  let s = autograd#sum(y, 1, 1)
  return autograd#div(y, s)
endfunction

function! s:cross_entropy_loss(y, t)
  let loss = autograd#mul(a:t, autograd#log(a:y))
  let batch_size = loss.shape[0]
  return autograd#div(autograd#sum(loss), batch_size).n()
endfunction

let s:MLP = {'params': []}
function! s:MLP(in_size, ...) abort
  let l:mlp = deepcopy(s:MLP)

  let std = sqrt(2.0 / a:in_size)
  let l:W = autograd#normal(0, std, [a:in_size, a:1])
  let l:b = autograd#zeros([a:1])
  let l:W.name = 'W0'
  let l:b.name = 'b0'
  let l:mlp.params += [l:W, l:b]

  for l:i in range(a:0 - 1)
    let std = sqrt(2.0 / a:000[l:i])
    let l:W = autograd#normal(0, std, [a:000[l:i], a:000[l:i + 1]])
    let l:W.name = 'W' . string(l:i + 1)
    let l:b = autograd#zeros([a:000[l:i + 1]])
    let l:b.name = 'b' . string(l:i + 1)
    let l:mlp.params += [l:W, l:b]
  endfor
  return l:mlp
endfunction

function! s:MLP.forward(x) abort
  let y = s:linear(a:x, self.params[0], self.params[1])
  for l:i in range(2, len(self.params) - 1, 2)
    let y = s:relu(y)
    let y = s:linear(y, self.params[l:i], self.params[l:i + 1])
  endfor
  let y = s:softmax(y)
  return y
endfunction
```
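For example, the network can be built and run on a dummy batch like this (a minimal sketch; it assumes `forward` accepts a Tensor as well as a nested list, and uses `autograd#normal` only to fabricate input data):

```vim
" 13 inputs -> 100 hidden units -> 3 classes
let s:model = s:MLP(13, 100, 3)

" a dummy batch of two 13-dimensional samples
let s:x = autograd#normal(0, 1, [2, 13])

let s:probs = s:model.forward(s:x)
echo s:probs.shape
" expected: [2, 3], with each row summing to roughly 1.0 after the softmax
```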

### Prepare Optimizer

SGD with momentum, weight decay, and gradient clipping can be implemented as follows.

`examples/wine-classify.vim`

```vim
let s:SGD = {
  \ 'vs': {},
  \ 'momentum': 0.9,
  \ 'lr': 0.01,
  \ 'weight_decay': 0.0,
  \ 'grad_clip': -1
  \ }
function! s:SGD.each_update(param) abort
  if self.weight_decay != 0
    call autograd#elementwise(
      \ [a:param.grad, a:param],
      \ {g, p -> g + self.weight_decay * p}, a:param.grad)
  endif

  if self.momentum == 0
    return autograd#elementwise(
      \ [a:param, a:param.grad], {p, g -> p - g * self.lr}, a:param)
  endif

  if !self.vs->has_key(a:param.id)
    let self.vs[a:param.id] = autograd#zeros_like(a:param)
  endif

  let v = self.vs[a:param.id]

  let v = autograd#sub(v.m(self.momentum), a:param.grad.m(self.lr))
  let self.vs[a:param.id] = v
  return autograd#elementwise([a:param, v], {a, b -> a + b}, a:param)
endfunction

function! s:SGD.step(params) abort
  " gradient clipping by the global L2 norm of all gradients
  if self.grad_clip > 0
    let grads_norm = 0.0
    for param in a:params
      let grads_norm += autograd#sum(param.grad.p(2)).data[0]
    endfor
    let grads_norm = sqrt(grads_norm)
    let clip_rate = self.grad_clip / (grads_norm + 0.000001)
    if clip_rate < 1.0
      for param in a:params
        let param.grad = param.grad.m(clip_rate)
      endfor
    endif
  endif

  call map(a:params, 'self.each_update(v:val)')
endfunction

function! s:SGD(...) abort
  let l:optim = deepcopy(s:SGD)
  let l:optim.lr = get(a:, 1, 0.01)
  let l:optim.momentum = get(a:, 2, 0.9)
  let l:optim.weight_decay = get(a:, 3, 0.0)
  let l:optim.grad_clip = get(a:, 4, -1)
  return l:optim
endfunction
```
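A minimal usage sketch of a single update step (the variable names are just for illustration, and it assumes `autograd#tensor` also accepts a list, which the rest of the API suggests):

```vim
" one parameter and a toy loss: L = sum(w * w), so dL/dw = 2w
let s:w = autograd#tensor([1.0, 2.0, 3.0])
let s:loss = autograd#sum(autograd#mul(s:w, s:w))
call s:loss.backward()

" lr = 0.1, default momentum = 0.9; the first step gives w - 0.1 * 2w
let s:optim = s:SGD(0.1)
call s:optim.step([s:w])
echo s:w.data
" expected roughly: [0.8, 1.6, 2.4]
```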

### Training

With the basic layers and optimizer above, the training loop can be written much like in a general deep learning framework (e.g. PyTorch or Chainer).

`examples/wine-classify.vim`

```vim
function! s:main() abort
  call autograd#manual_seed(42)

  let data = s:get_wine_dataset()
  let model = s:MLP(data['insize'], 100, data['nclass'])
  let optimizer = s:SGD(0.1, 0.9, 0.0001, 10.0)

  " train
  let max_epoch = 50
  let batch_size = 16
  let train_data_num = len(data['train'][0])
  let each_iteration = float2nr(ceil(1.0 * train_data_num / batch_size))

  let logs = []
  for epoch in range(max_epoch)
    let indexes = autograd#shuffle(range(train_data_num))
    let epoch_loss = 0.0
    for l:i in range(each_iteration)
      let x = []
      let t = []
      for index in indexes[l:i * batch_size:(l:i + 1) * batch_size - 1]
        call add(x, data['train'][0][index])

        let onehot = repeat([0.0], data['nclass'])
        let onehot[float2nr(data['train'][1][index][0])] = 1.0
        call add(t, onehot)
      endfor

      let y = model.forward(x)
      let loss = s:cross_entropy_loss(y, t)
      " call autograd#dump_graph(loss, '.autograd/loss.png')

      for param in model.params
        call param.cleargrad()
      endfor
      call loss.backward()

      call optimizer.step(model.params)
      let l:epoch_loss += loss.data[0]
    endfor

    let l:epoch_loss /= each_iteration

    " logging
    call add(logs, epoch . ', ' . string(l:epoch_loss))
    call writefile(logs, '.autograd/train.log')
  endfor

  " evaluate
  let ng = autograd#no_grad()
  let accuracy = 0.0
  for l:i in range(len(data['test'][0]))
    let pred = model.forward([data['test'][0][l:i]])

    " argmax
    let class_idx = index(pred.data, autograd#max(pred).data[0])
    let accuracy += class_idx == data['test'][1][l:i][0]
  endfor
  call ng.end()

  echomsg 'accuracy: ' . string(accuracy / len(data['test'][1]))
endfunction
```

When `s:main()` is executed, the loss decreases as follows, and training completes in a few minutes.

```
0, 0.379945
1, 0.094833
2, 0.029978
3, 0.002876
4, 0.027007
5, 0.065495
6, 0.020479
7, 0.01342
8, 0.046886
9, 0.042945
...
```

Output

```
accuracy: 0.966667
```

The computational graph generated is shown below.