Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions examples/mnist/mnist.js
Original file line number Diff line number Diff line change
@@ -1,25 +1,43 @@
function loadMNIST(callback) {
let mnist = {};
loadFile('t10k-images-idx3-ubyte', 16)
.then(data => {
mnist.test_images = data;
return loadFile('t10k-labels-idx1-ubyte', 8);
})
.then(data => {
mnist.test_labels = data;
return loadFile('train-images-idx3-ubyte', 16);
}).then(data => {
mnist.train_images = data;
return loadFile('train-labels-idx1-ubyte', 8);
})
.then(data => {
mnist.train_labels = data;
callback(mnist);
});
let files = {
train_images: 'train-images-idx3-ubyte',
train_labels: 'train-labels-idx1-ubyte',
test_images: 't10k-images-idx3-ubyte',
test_labels: 't10k-labels-idx1-ubyte',
};
return Promise.all(Object.keys(files).map(async file => {
mnist[file] = await loadFile(files[file])
}))
.then(() => callback(mnist));
}

async function loadFile(file, offset) {
let r = await fetch(file);
let data = await r.arrayBuffer();
return new Uint8Array(data).slice(offset);
async function loadFile(file) {
let buffer = await fetch(file).then(r => r.arrayBuffer());
let headerCount = 4;
let headerView = new DataView(buffer, 0, 4 * headerCount);
let headers = new Array(headerCount).fill().map((_, i) => headerView.getUint32(4 * i, false));

// Get file type from the magic number
let type, dataLength;
if(headers[0] == 2049) {
type = 'label';
dataLength = 1;
headerCount = 2;
} else if(headers[0] == 2051) {
type = 'image';
dataLength = headers[2] * headers[3];
} else {
throw new Error("Unknown file type " + headers[0])
}

let data = new Uint8Array(buffer, headerCount * 4);
if(type == 'image') {
dataArr = [];
for(let i = 0; i < headers[1]; i++) {
dataArr.push(data.subarray(dataLength * i, dataLength * (i + 1)));
}
return dataArr;
}
return data;
}
4 changes: 2 additions & 2 deletions examples/mnist/sketch.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function train(show) {
train_image.loadPixels();
}
for (let i = 0; i < 784; i++) {
let bright = mnist.train_images[i + train_index * 784];
let bright = mnist.train_images[train_index][i];
inputs[i] = bright / 255;
if (show) {
let index = i * 4;
Expand Down Expand Up @@ -80,7 +80,7 @@ function train(show) {
function testing() {
let inputs = [];
for (let i = 0; i < 784; i++) {
let bright = mnist.test_images[i + test_index * 784];
let bright = mnist.test_images[test_index][i];
inputs[i] = bright / 255;
}
let label = mnist.test_labels[test_index];
Expand Down