以前書いた記事〈Node.js版〉TensorFlow.jsで作成した非線形回帰モデルのブラウザ版です。WASMバックエンドを使って高速化を図っています。Node.js版に比べ、処理速度が30%程度落ちるレベルです。
Node.js版からの改良点は、ユニット数をコンストラクタの引数から設定できるようにしたことです。これにより、ユニット数の異なる複数のモデルを生成し実行できるようになりました。 また、エポック数も train(x, y, epochs) の引数から入力できるようにしました。
※ブラウザはGoogle Chromeがお勧めです。Mozilla Firefoxは遅いです。古いブラウザでは動作しません!
をクリックすると学習がスタートします。学習が完了すると"Completed!"と表示されるので、2ヶ所の入力フィールドに1~9の数値を入力し、をクリックすると結果が表示されます。 当たり前のことですが、学習データの範囲外の数値を入力すると、とんでもない値が出力されます。
ちなみに学習時間は、Celeron搭載の私のPCでChromeを使って3分近くかかりました。古いPCをお使いの方は気長にお待ちください^^
〈2021年10月24日追記〉
当初、WASMバックエンドが有効だったのですが、今日実行するとなぜかWASMバックエンドが効いていないことがわかりました。原因不明です。
そのため学習時間が私のPCで10分以上かかるようになりました!
〈掛け算の九九の表を学習して乗算をする3層ニューラルネットワーク〉
tensorflow.html
<!doctype html>
<html>
<body>
<div>
<button id="ButtonTrain">train</button>
<input id="Input1" type="text"> × <input id="Input2" type="text">
<button id="ButtonPredict">→</button> <output id="Output"></output>
</div>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-wasm/dist/tf-backend-wasm.js"></script>
<script>
class NeuralNetwork {
constructor(units1, units2, units3) {
/* units1:第1層(入力層)のユニット数 */
/* units2:第2層(隠れ層)のユニット数 */
/* units3:第3層(出力層)のユニット数 */
this.w1 = tf.variable(tf.randomUniform([units1, units2], -1, 1));
this.b1 = tf.variable(tf.randomUniform([units2], -1, 1));
this.w2 = tf.variable(tf.randomUniform([units2, units3], -1, 1));
this.b2 = tf.variable(tf.randomUniform([units3], -1, 1));
}
model(x) {
return x.matMul(this.w1).add(this.b1).tanh().matMul(this.w2).add(this.b2);
}
loss(x, y) {
return tf.losses.meanSquaredError(this.model(x), y);
}
train(x, y, epochs) {
tf.setBackend("wasm");
tf.ready().then(() => {
const optimizer = tf.train.adam(0.001, 0.9, 0.999, 0.00000001);
for (let epoch = 1; epoch <= epochs; epoch++) {
optimizer.minimize(() => this.loss(tf.tensor(x), tf.tensor(y)));
}
alert("Completed!");
});
console.log(tf.getBackend());
}
predict(x) {
return this.model(tf.tensor([x])).arraySync()[0];
}
}
</script>
<script>
const neuralNetwork = new NeuralNetwork(units1 = 2, units2 = 7, units3 = 1);
document.getElementById("ButtonTrain").onclick = () => {
/* 掛け算の九九の表 */
const x_train = Array.from((function* () { for (let x1 = 1; x1 <= 9; x1++) for (let x2 = 1; x2 <= 9; x2++) yield [x1, x2]; })());
const y_train = Array.from((function* () { for (let x1 = 1; x1 <= 9; x1++) for (let x2 = 1; x2 <= 9; x2++) yield [x1 * x2]; })());
/* 学習 */
neuralNetwork.train(x_train, y_train, epochs = 50000);
};
document.getElementById("ButtonPredict").onclick = () => {
/* 推論 */
const x1 = Number(document.getElementById("Input1").value);
const x2 = Number(document.getElementById("Input2").value);
document.getElementById("Output").innerHTML = neuralNetwork.predict([x1, x2]);
};
</script>
</body>
</html>
参考記事
- npm:@tensorflow/tfjs
- npm:@tensorflow/tfjs-backend-wasm
- Introducing the WebAssembly backend for TensorFlow.js
- Supercharging the TensorFlow.js WebAssembly backend with SIMD and multi-threading