Продолжает строить и обучать нейросети в браузере с использованием tensorflow.js и сегодня классический пример для сверточных нейронных сетей — задача распознавания рукописных цифр на основе MNIST. Прямо в браузере будет обучаться нейронка плюс будет показана панель (визор) с указанием архитектуры модели и графики потерь и точности.
Код страницы:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
<html> <head> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script> </head> <body> <h1>Классификатор рукописных цифр!</h1> <canvas id="canvas" width="280" height="280" style="position:absolute;top:100;left:100;border:8px solid;"></canvas> <img id="canvasimg" style="position:absolute;top:10%;left:52%;width=280;height=280;display:none;"> <input type="button" value="Классифицировать" id="sb" size="48" style="position:absolute;top:400;left:100;"> <input type="button" value="Очистить" id="cb" size="23" style="position:absolute;top:400;left:280;"> <script src="data.js" type="module"></script> <script src="script.js" type="module"></script> </body> </html> |
Код основного скрипта с нейросетью:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import {MnistData} from './data.js'; var canvas, ctx, saveButton, clearButton; var pos = {x:0, y:0}; var rawImage; var model; function getModel() { model = tf.sequential(); model.add(tf.layers.conv2d({inputShape: [28, 28, 1], kernelSize: 3, filters: 8, activation: 'relu'})); model.add(tf.layers.maxPooling2d({poolSize: [2, 2]})); model.add(tf.layers.conv2d({kernelSize: 3, filters: 16, activation: 'relu'})); model.add(tf.layers.maxPooling2d({poolSize: [2, 2]})); model.add(tf.layers.flatten()); model.add(tf.layers.dense({units: 128, activation: 'relu'})); model.add(tf.layers.dense({units: 10, activation: 'softmax'})); model.compile({optimizer: tf.train.adam(), loss: 'categoricalCrossentropy', metrics: ['accuracy']}); return model; } async function train(model, data) { const metrics = ['loss', 'val_loss', 'accuracy', 'val_accuracy']; const container = { name: 'Обучение модели', styles: { height: '640px' } }; const fitCallbacks = tfvis.show.fitCallbacks(container, metrics); const BATCH_SIZE = 512; const TRAIN_DATA_SIZE = 5500; const TEST_DATA_SIZE = 1000; const [trainXs, trainYs] = tf.tidy(() => { const d = data.nextTrainBatch(TRAIN_DATA_SIZE); return [ d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]), d.labels ]; }); const [testXs, testYs] = tf.tidy(() => { const d = data.nextTestBatch(TEST_DATA_SIZE); return [ d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]), d.labels ]; }); return model.fit(trainXs, trainYs, { batchSize: BATCH_SIZE, validationData: [testXs, testYs], epochs: 20, shuffle: true, callbacks: fitCallbacks }); } function setPosition(e){ pos.x = e.clientX-100; pos.y = e.clientY-100; } function draw(e) { if(e.buttons!=1) return; ctx.beginPath(); ctx.lineWidth = 22; ctx.lineCap = 'round'; ctx.strokeStyle = 'yellow'; ctx.moveTo(pos.x, pos.y); setPosition(e); ctx.lineTo(pos.x, pos.y); ctx.stroke(); rawImage.src = canvas.toDataURL('image/png'); } function erase() { ctx.fillStyle = "black"; ctx.fillRect(0,0,280,280); } function save() { var raw = tf.browser.fromPixels(rawImage,1); var resized = tf.image.resizeBilinear(raw, [28,28]); var tensor = resized.expandDims(0); var prediction = model.predict(tensor); var pIndex = tf.argMax(prediction, 1).dataSync(); alert(pIndex); } function init() { canvas = document.getElementById('canvas'); rawImage = document.getElementById('canvasimg'); ctx = canvas.getContext("2d"); ctx.fillStyle = "black"; ctx.fillRect(0,0,280,280); canvas.addEventListener("mousemove", draw); canvas.addEventListener("mousedown", setPosition); canvas.addEventListener("mouseenter", setPosition); saveButton = document.getElementById('sb'); saveButton.addEventListener("click", save); clearButton = document.getElementById('cb'); clearButton.addEventListener("click", erase); } async function run() { const data = new MnistData(); await data.load(); const model = getModel(); tfvis.show.modelSummary({name: 'Архитектура модели'}, model); await train(model, data); init(); alert("Обучение закончено, попробуйте классификатор на рукописных цифрах!"); } document.addEventListener('DOMContentLoaded', run); |
Файл data.js можно скачать здесь.
Видео (сверточная нейросеть в браузере)