diff --git a/examples/mnist/mnist.js b/examples/mnist/mnist.js index fe3306d..6facc9d 100644 --- a/examples/mnist/mnist.js +++ b/examples/mnist/mnist.js @@ -1,25 +1,43 @@ function loadMNIST(callback) { let mnist = {}; - loadFile('t10k-images-idx3-ubyte', 16) - .then(data => { - mnist.test_images = data; - return loadFile('t10k-labels-idx1-ubyte', 8); - }) - .then(data => { - mnist.test_labels = data; - return loadFile('train-images-idx3-ubyte', 16); - }).then(data => { - mnist.train_images = data; - return loadFile('train-labels-idx1-ubyte', 8); - }) - .then(data => { - mnist.train_labels = data; - callback(mnist); - }); + let files = { + train_images: 'train-images-idx3-ubyte', + train_labels: 'train-labels-idx1-ubyte', + test_images: 't10k-images-idx3-ubyte', + test_labels: 't10k-labels-idx1-ubyte', + }; + return Promise.all(Object.keys(files).map(async file => { + mnist[file] = await loadFile(files[file]) + })) + .then(() => callback(mnist)); } -async function loadFile(file, offset) { - let r = await fetch(file); - let data = await r.arrayBuffer(); - return new Uint8Array(data).slice(offset); +async function loadFile(file) { + let buffer = await fetch(file).then(r => r.arrayBuffer()); + let headerCount = 4; + let headerView = new DataView(buffer, 0, 4 * headerCount); + let headers = new Array(headerCount).fill().map((_, i) => headerView.getUint32(4 * i, false)); + + // Get file type from the magic number + let type, dataLength; + if(headers[0] == 2049) { + type = 'label'; + dataLength = 1; + headerCount = 2; + } else if(headers[0] == 2051) { + type = 'image'; + dataLength = headers[2] * headers[3]; + } else { + throw new Error("Unknown file type " + headers[0]) + } + + let data = new Uint8Array(buffer, headerCount * 4); + if(type == 'image') { + dataArr = []; + for(let i = 0; i < headers[1]; i++) { + dataArr.push(data.subarray(dataLength * i, dataLength * (i + 1))); + } + return dataArr; + } + return data; } diff --git a/examples/mnist/sketch.js b/examples/mnist/sketch.js index bacae86..6d25576 100644 --- a/examples/mnist/sketch.js +++ b/examples/mnist/sketch.js @@ -37,7 +37,7 @@ function train(show) { train_image.loadPixels(); } for (let i = 0; i < 784; i++) { - let bright = mnist.train_images[i + train_index * 784]; + let bright = mnist.train_images[train_index][i]; inputs[i] = bright / 255; if (show) { let index = i * 4; @@ -80,7 +80,7 @@ function train(show) { function testing() { let inputs = []; for (let i = 0; i < 784; i++) { - let bright = mnist.test_images[i + test_index * 784]; + let bright = mnist.test_images[test_index][i]; inputs[i] = bright / 255; } let label = mnist.test_labels[test_index];