Skip to content

Commit

Permalink
Merge pull request #8 from EiffL/Seattle2019
Browse files Browse the repository at this point in the history
Adds missing files
  • Loading branch information
EiffL authored Oct 15, 2019
2 parents 77ec989 + d682153 commit 3a9ead8
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 0 deletions.
272 changes: 272 additions & 0 deletions Seattle2019/dgm_prior.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
<!DOCTYPE html>
<html>
<head>
<title>Generative Models as Priors for Inverse Problems</title>
<style>
body {
height: 500px;
width: 500px;
}
/* #animation {
position: absolute;
top: 0px;
left: 0px;
background: #000;
} body {
text-align: center;
}

#mynetwork {
height: 500px;
} */
</style>

<!-- Import TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.1.0/dist/tf.min.js"></script>
<script src="https://d3js.org/d3.v5.min.js"></script>
</head>

<body>
<canvas id="animation" height="500" width="500"></canvas>
<script>

(async function animation() {
// This function is closely modeled on http://bl.ocks.org/newby-jay/767c5ffdbbe43b65902f
const model = await tf.loadGraphModel('models/js/export3/model.json');
const grads = tf.grad(x => model.predict(x));
const vfunc = (x,y) => {
c = tf.concat([tf.reshape(x, [-1,1] ), tf.reshape(y, [-1,1])], axis=1);
gr = grads(c);
gn = tf.sum(tf.mul(gr,gr), axis=1, keepDims=true);
gr = tf.mul(tf.div(gr, gn), tf.clipByValue(gn,0,100));

return tf.split(gr, 2, axis=1);
};


// vector field data
var dt = 0.003,
X0 = [], Y0 = [], // to store initial starting locations
X = [], Y = [], // to store current point for each curve
Xd=[], Yd=[],
xb = 4, yb = 3;
var sigma=0.2;
var X1=0.5, Y1=0.5, X2=0, Y2=0;
var XC=1, YC=1;
var width = 500, height = 500;

// First draw the modelled density in the background
var N = 128
var xd = d3.range(N).map(
function (i) {
return -1.5 + xb*i/N;
}),
yd = d3.range(N).map(
function (i) {
return -1 + yb*i/N;
});
// array of starting positions for each curve on a uniform grid
for (var i = 0; i < N; i++) {
for (var j = 0; j < N; j++) {
Xd.push(xd[j]), Yd.push(yd[i]);
}
}

// Compute the density field in this input resolution and rescale it to output res
const logp = tf.tidy(() => {
const c = tf.concat([tf.reshape(Xd, [-1,1] ), tf.reshape(Yd, [-1,1])], axis=1);
const out = model.predict(c);
const out_resized = tf.exp(tf.image.resizeBilinear(tf.reshape(out, [N,N,1]), [width, height]));
return out_resized.dataSync();
});

// Store this array as image data
var g = d3.select("#animation").node().getContext("2d");
var imagedata = g.createImageData(width, height);
for (var x=0; x<width; x++) {
for (var y=0; y<height; y++) {
var pixelindex = (y * width + x) * 4;
// Generate a xor pattern with some random noise
var po = logp[((height -1 - y) * width + x)]*0.5;
if(isNaN(po)){ po = 0; }
c = d3.rgb(d3.interpolateInferno(po));
// Set the pixel data
imagedata.data[pixelindex] = c.r; // Red
imagedata.data[pixelindex+1] = c.g; // discretize the vfield coordsgreen; // Green
imagedata.data[pixelindex+2] = c.b; // Blue
imagedata.data[pixelindex+3] = 255; // Alpha
}
}
g.putImageData(imagedata,0,0);
for (var x=0; x<width; x++) {
for (var y=0; y<height; y++) {
var pixelindex = (y * width + x) * 4;
// Generate a xor pattern with some random noise
var po = logp[((height -1 - y) * width + x)]*0.5;
if(isNaN(po)){ po = 0; }
c = d3.rgb(d3.interpolateInferno(po));
// Set the pixel data
imagedata.data[pixelindex] = c.r; // Red
imagedata.data[pixelindex+1] = c.g; // discretize the vfield coordsgreen; // Green
imagedata.data[pixelindex+2] = c.b; // Blue
imagedata.data[pixelindex+3] = 25; // Alpha
}
}

var N = 50;
var xp = d3.range(N).map(
function (i) {
return -1.5 + xb*i/N;
}),
yp = d3.range(N).map(
function (i) {
return -1 + yb*i/N;
});
// array of starting positions for each curve on a uniform grid
for (var i = 0; i < N; i++) {
for (var j = 0; j < N; j++) {
X.push(xp[j]), Y.push(yp[i]);
X0.push(xp[j]), Y0.push(yp[i]);
}
}

// // vfield
function F(x, y) {
const [px, py] = tf.tidy(() => {
const [predx, predy] = vfunc(x, y);
return [predx.dataSync(), predy.dataSync()];
});
return [px, py];
}

//// frame setup
var mw = 0;

g.lineWidth = 0.8;
g.strokeStyle = "#FF8000"; // html color code

//// mapping from vfield coords to web page coords
var xMap = d3.scaleLinear()
.domain([-1.5, 2.5])
.range([mw, width - mw]),
yMap = d3.scaleLinear()
.domain([-1, 2.])
.range([height - mw, mw]);
//// animation setup
var animAge = 0,
frameRate = 30, // ms per timestep (yeah I know it's not really a rate)
M = X.length,
thr=200,
MaxAge = 100, // # timesteps before restart
age = [];

for (var i=0; i<M; i++) {age.push(randage());}
var drawFlag = false;

d3.timer(function () {if (drawFlag) {draw();}}, frameRate);
d3.select("#animation")
.on("click", function() {
var mouse = d3.mouse(this);
XC = xMap.invert(mouse[0]);
YC = yMap.invert(mouse[1]);
})

d3.select("body").on("keypress", function() {
if(d3.event.keyCode === 32 || d3.event.keyCode === 13){
drawFlag = (drawFlag) ? false : true;
}
if(d3.event.keyCode === 61 ){
sigma = sigma*2.;
}
if(d3.event.keyCode === 45 ){
sigma /= 2.;
}
})
function randage() {
// to randomize starting ages for each curve
return Math.round(Math.random()*100);
}

var overlayCanvas = document.createElement("canvas");
overlayCanvas.width = width;
overlayCanvas.height = height;
overlayCanvas.getContext("2d").putImageData(imagedata, 0, 0);
g.imageSmoothingEnabled = false;

// for info on the global canvas operations see
// http://bucephalus.org/text/CanvasHandbook/CanvasHandbook.html#globalcompositeoperation
g.globalCompositeOperation = "source-over";
function draw() {
var s = (xMap(sigma) - xMap(0));
//g.fillRect(0, 0, width, height); // fades all existing curves by a set amount determined by fillStyle (above), which sets opacity using rgba
//g.putImageData(imagedata,0,0);
g.drawImage(overlayCanvas,0,0);
// Compute dr for all points
g.lineWidth = 1.5;
g.strokeStyle = "#FF8000"; // html color code
var [dx, dy] = F(X, Y);
for (var i=0; i<M; i++) { // draw a single timestep for every curve
// if dx dy is larger than our threshold, we don't need to move this point
if((dx[i]**2 + dy[i]**2) < thr){
g.beginPath();
g.moveTo(xMap(X[i]), yMap(Y[i])); // the start point of the path
g.lineTo(xMap(X[i]+=dx[i]*dt), yMap(Y[i]+=dy[i]*dt)); // the end point
g.stroke(); // final draw command
};
if (age[i]++ > MaxAge) {
// incriment age of each curve, restart if MaxAge is reached
age[i] = randage();
X[i] = X0[i], Y[i] = Y0[i];
}
}
// Computes gradients of the solution
var [dx, dy] = F([X1, X2], [Y1, Y2]);
dx[0]+= 0.5*(XC - (X1+X2))/sigma/sigma;
dx[1]+= 0.5*(XC - (X1+X2))/sigma/sigma;
dy[0]+= 0.5*(YC - (Y1+Y2))/sigma/sigma;
dy[1]+= 0.5*(YC - (Y1+Y2))/sigma/sigma;

// Draw solution points
g.lineWidth = 14;
g.strokeStyle = g.fillStyle = "#ADFF2F"; // html color code
XS=X1+X2; YS=Y1+Y2;
g.beginPath();
g.moveTo(xMap(X1), yMap(Y1));
g.lineTo(xMap(X1+=dx[0]*dt), yMap(Y1+=dy[0]*dt));
g.stroke();
g.beginPath();
g.arc(xMap(X1), yMap(Y1), 7, 0, 2 * Math.PI);
g.fill();

g.strokeStyle = g.fillStyle = "#96CDFF"; // html color code
g.beginPath();
g.moveTo(xMap(X2), yMap(Y2));
g.lineTo(xMap(X2+=dx[1]*dt), yMap(Y2+=dy[1]*dt));
g.stroke();
g.beginPath();
g.arc(xMap(X2), yMap(Y2), 7, 0, 2 * Math.PI);
g.fill();

g.strokeStyle = g.fillStyle = "#E32E52";//#896ED1"; // html color code
g.beginPath();
g.moveTo(xMap(XS), yMap(YS));
XS=X1+X2; YS=Y1+Y2;
g.lineTo(xMap(XS), yMap(YS));
g.stroke();
g.beginPath();
g.arc(xMap(XS), yMap(YS), 7, 0, 2 * Math.PI);
g.fill();

g.beginPath();
g.strokeStyle = "#C94277";
g.lineWidth = 1.5;
g.arc(xMap(XC), yMap(YC), s, 0, 2 * Math.PI);
g.stroke();

}
})()

</script>
</body>

</html>
Binary file not shown.
1 change: 1 addition & 0 deletions Seattle2019/models/js/export3/model.json

Large diffs are not rendered by default.

0 comments on commit 3a9ead8

Please sign in to comment.