-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from EiffL/Seattle2019
Adds missing files
- Loading branch information
Showing
3 changed files
with
273 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,272 @@ | ||
<!DOCTYPE html> | ||
<html> | ||
<head> | ||
<title>Generative Models as Priors for Inverse Problems</title> | ||
<style> | ||
body { | ||
height: 500px; | ||
width: 500px; | ||
} | ||
/* #animation { | ||
position: absolute; | ||
top: 0px; | ||
left: 0px; | ||
background: #000; | ||
} body { | ||
text-align: center; | ||
} | ||
|
||
#mynetwork { | ||
height: 500px; | ||
} */ | ||
</style> | ||
|
||
<!-- Import TensorFlow.js --> | ||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.1.0/dist/tf.min.js"></script> | ||
<script src="https://d3js.org/d3.v5.min.js"></script> | ||
</head> | ||
|
||
<body> | ||
<canvas id="animation" height="500" width="500"></canvas> | ||
<script> | ||
|
||
(async function animation() { | ||
// This function is closely modeled on http://bl.ocks.org/newby-jay/767c5ffdbbe43b65902f | ||
const model = await tf.loadGraphModel('models/js/export3/model.json'); | ||
const grads = tf.grad(x => model.predict(x)); | ||
const vfunc = (x,y) => { | ||
c = tf.concat([tf.reshape(x, [-1,1] ), tf.reshape(y, [-1,1])], axis=1); | ||
gr = grads(c); | ||
gn = tf.sum(tf.mul(gr,gr), axis=1, keepDims=true); | ||
gr = tf.mul(tf.div(gr, gn), tf.clipByValue(gn,0,100)); | ||
|
||
return tf.split(gr, 2, axis=1); | ||
}; | ||
|
||
|
||
// vector field data | ||
var dt = 0.003, | ||
X0 = [], Y0 = [], // to store initial starting locations | ||
X = [], Y = [], // to store current point for each curve | ||
Xd=[], Yd=[], | ||
xb = 4, yb = 3; | ||
var sigma=0.2; | ||
var X1=0.5, Y1=0.5, X2=0, Y2=0; | ||
var XC=1, YC=1; | ||
var width = 500, height = 500; | ||
|
||
// First draw the modelled density in the background | ||
var N = 128 | ||
var xd = d3.range(N).map( | ||
function (i) { | ||
return -1.5 + xb*i/N; | ||
}), | ||
yd = d3.range(N).map( | ||
function (i) { | ||
return -1 + yb*i/N; | ||
}); | ||
// array of starting positions for each curve on a uniform grid | ||
for (var i = 0; i < N; i++) { | ||
for (var j = 0; j < N; j++) { | ||
Xd.push(xd[j]), Yd.push(yd[i]); | ||
} | ||
} | ||
|
||
// Compute the density field in this input resolution and rescale it to output res | ||
const logp = tf.tidy(() => { | ||
const c = tf.concat([tf.reshape(Xd, [-1,1] ), tf.reshape(Yd, [-1,1])], axis=1); | ||
const out = model.predict(c); | ||
const out_resized = tf.exp(tf.image.resizeBilinear(tf.reshape(out, [N,N,1]), [width, height])); | ||
return out_resized.dataSync(); | ||
}); | ||
|
||
// Store this array as image data | ||
var g = d3.select("#animation").node().getContext("2d"); | ||
var imagedata = g.createImageData(width, height); | ||
for (var x=0; x<width; x++) { | ||
for (var y=0; y<height; y++) { | ||
var pixelindex = (y * width + x) * 4; | ||
// Generate a xor pattern with some random noise | ||
var po = logp[((height -1 - y) * width + x)]*0.5; | ||
if(isNaN(po)){ po = 0; } | ||
c = d3.rgb(d3.interpolateInferno(po)); | ||
// Set the pixel data | ||
imagedata.data[pixelindex] = c.r; // Red | ||
imagedata.data[pixelindex+1] = c.g; // discretize the vfield coordsgreen; // Green | ||
imagedata.data[pixelindex+2] = c.b; // Blue | ||
imagedata.data[pixelindex+3] = 255; // Alpha | ||
} | ||
} | ||
g.putImageData(imagedata,0,0); | ||
for (var x=0; x<width; x++) { | ||
for (var y=0; y<height; y++) { | ||
var pixelindex = (y * width + x) * 4; | ||
// Generate a xor pattern with some random noise | ||
var po = logp[((height -1 - y) * width + x)]*0.5; | ||
if(isNaN(po)){ po = 0; } | ||
c = d3.rgb(d3.interpolateInferno(po)); | ||
// Set the pixel data | ||
imagedata.data[pixelindex] = c.r; // Red | ||
imagedata.data[pixelindex+1] = c.g; // discretize the vfield coordsgreen; // Green | ||
imagedata.data[pixelindex+2] = c.b; // Blue | ||
imagedata.data[pixelindex+3] = 25; // Alpha | ||
} | ||
} | ||
|
||
var N = 50; | ||
var xp = d3.range(N).map( | ||
function (i) { | ||
return -1.5 + xb*i/N; | ||
}), | ||
yp = d3.range(N).map( | ||
function (i) { | ||
return -1 + yb*i/N; | ||
}); | ||
// array of starting positions for each curve on a uniform grid | ||
for (var i = 0; i < N; i++) { | ||
for (var j = 0; j < N; j++) { | ||
X.push(xp[j]), Y.push(yp[i]); | ||
X0.push(xp[j]), Y0.push(yp[i]); | ||
} | ||
} | ||
|
||
// // vfield | ||
function F(x, y) { | ||
const [px, py] = tf.tidy(() => { | ||
const [predx, predy] = vfunc(x, y); | ||
return [predx.dataSync(), predy.dataSync()]; | ||
}); | ||
return [px, py]; | ||
} | ||
|
||
//// frame setup | ||
var mw = 0; | ||
|
||
g.lineWidth = 0.8; | ||
g.strokeStyle = "#FF8000"; // html color code | ||
|
||
//// mapping from vfield coords to web page coords | ||
var xMap = d3.scaleLinear() | ||
.domain([-1.5, 2.5]) | ||
.range([mw, width - mw]), | ||
yMap = d3.scaleLinear() | ||
.domain([-1, 2.]) | ||
.range([height - mw, mw]); | ||
//// animation setup | ||
var animAge = 0, | ||
frameRate = 30, // ms per timestep (yeah I know it's not really a rate) | ||
M = X.length, | ||
thr=200, | ||
MaxAge = 100, // # timesteps before restart | ||
age = []; | ||
|
||
for (var i=0; i<M; i++) {age.push(randage());} | ||
var drawFlag = false; | ||
|
||
d3.timer(function () {if (drawFlag) {draw();}}, frameRate); | ||
d3.select("#animation") | ||
.on("click", function() { | ||
var mouse = d3.mouse(this); | ||
XC = xMap.invert(mouse[0]); | ||
YC = yMap.invert(mouse[1]); | ||
}) | ||
|
||
d3.select("body").on("keypress", function() { | ||
if(d3.event.keyCode === 32 || d3.event.keyCode === 13){ | ||
drawFlag = (drawFlag) ? false : true; | ||
} | ||
if(d3.event.keyCode === 61 ){ | ||
sigma = sigma*2.; | ||
} | ||
if(d3.event.keyCode === 45 ){ | ||
sigma /= 2.; | ||
} | ||
}) | ||
function randage() { | ||
// to randomize starting ages for each curve | ||
return Math.round(Math.random()*100); | ||
} | ||
|
||
var overlayCanvas = document.createElement("canvas"); | ||
overlayCanvas.width = width; | ||
overlayCanvas.height = height; | ||
overlayCanvas.getContext("2d").putImageData(imagedata, 0, 0); | ||
g.imageSmoothingEnabled = false; | ||
|
||
// for info on the global canvas operations see | ||
// http://bucephalus.org/text/CanvasHandbook/CanvasHandbook.html#globalcompositeoperation | ||
g.globalCompositeOperation = "source-over"; | ||
function draw() { | ||
var s = (xMap(sigma) - xMap(0)); | ||
//g.fillRect(0, 0, width, height); // fades all existing curves by a set amount determined by fillStyle (above), which sets opacity using rgba | ||
//g.putImageData(imagedata,0,0); | ||
g.drawImage(overlayCanvas,0,0); | ||
// Compute dr for all points | ||
g.lineWidth = 1.5; | ||
g.strokeStyle = "#FF8000"; // html color code | ||
var [dx, dy] = F(X, Y); | ||
for (var i=0; i<M; i++) { // draw a single timestep for every curve | ||
// if dx dy is larger than our threshold, we don't need to move this point | ||
if((dx[i]**2 + dy[i]**2) < thr){ | ||
g.beginPath(); | ||
g.moveTo(xMap(X[i]), yMap(Y[i])); // the start point of the path | ||
g.lineTo(xMap(X[i]+=dx[i]*dt), yMap(Y[i]+=dy[i]*dt)); // the end point | ||
g.stroke(); // final draw command | ||
}; | ||
if (age[i]++ > MaxAge) { | ||
// incriment age of each curve, restart if MaxAge is reached | ||
age[i] = randage(); | ||
X[i] = X0[i], Y[i] = Y0[i]; | ||
} | ||
} | ||
// Computes gradients of the solution | ||
var [dx, dy] = F([X1, X2], [Y1, Y2]); | ||
dx[0]+= 0.5*(XC - (X1+X2))/sigma/sigma; | ||
dx[1]+= 0.5*(XC - (X1+X2))/sigma/sigma; | ||
dy[0]+= 0.5*(YC - (Y1+Y2))/sigma/sigma; | ||
dy[1]+= 0.5*(YC - (Y1+Y2))/sigma/sigma; | ||
|
||
// Draw solution points | ||
g.lineWidth = 14; | ||
g.strokeStyle = g.fillStyle = "#ADFF2F"; // html color code | ||
XS=X1+X2; YS=Y1+Y2; | ||
g.beginPath(); | ||
g.moveTo(xMap(X1), yMap(Y1)); | ||
g.lineTo(xMap(X1+=dx[0]*dt), yMap(Y1+=dy[0]*dt)); | ||
g.stroke(); | ||
g.beginPath(); | ||
g.arc(xMap(X1), yMap(Y1), 7, 0, 2 * Math.PI); | ||
g.fill(); | ||
|
||
g.strokeStyle = g.fillStyle = "#96CDFF"; // html color code | ||
g.beginPath(); | ||
g.moveTo(xMap(X2), yMap(Y2)); | ||
g.lineTo(xMap(X2+=dx[1]*dt), yMap(Y2+=dy[1]*dt)); | ||
g.stroke(); | ||
g.beginPath(); | ||
g.arc(xMap(X2), yMap(Y2), 7, 0, 2 * Math.PI); | ||
g.fill(); | ||
|
||
g.strokeStyle = g.fillStyle = "#E32E52";//#896ED1"; // html color code | ||
g.beginPath(); | ||
g.moveTo(xMap(XS), yMap(YS)); | ||
XS=X1+X2; YS=Y1+Y2; | ||
g.lineTo(xMap(XS), yMap(YS)); | ||
g.stroke(); | ||
g.beginPath(); | ||
g.arc(xMap(XS), yMap(YS), 7, 0, 2 * Math.PI); | ||
g.fill(); | ||
|
||
g.beginPath(); | ||
g.strokeStyle = "#C94277"; | ||
g.lineWidth = 1.5; | ||
g.arc(xMap(XC), yMap(YC), s, 0, 2 * Math.PI); | ||
g.stroke(); | ||
|
||
} | ||
})() | ||
|
||
</script> | ||
</body> | ||
|
||
</html> |
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.