最近は、Keras API(高水準API)を使うのが流行のようですが、ここでは TensorFlow.js の Core API(低水準API)を使って非線形回帰モデルを作成する方法を「掛け算の九九の表を学習して乗算をするニューラルネットワーク」を例にとって紹介します。Core API でも簡単に回帰モデルを作成することができます。
※ブラウザ版はこちらです。
※TensorFlow.js のインストール手順や実行方法については、こちらの記事を参考にしてください。
〈掛け算の九九の表を学習して乗算をする3層ニューラルネットワーク〉
import * as fs from "fs";
import * as os from "os";
import * as tf from "@tensorflow/tfjs";
class NeuralNetwork {
units1 = 2; /* 第1層(入力層)のユニット数 */
units2 = 7; /* 第2層(隠れ層)のユニット数 */
units3 = 1; /* 第3層(出力層)のユニット数 */
epochs = 50000; /* エポック数 */
w1 = tf.variable(tf.randomUniform([this.units1, this.units2], -1, 1));
b1 = tf.variable(tf.randomUniform([this.units2], -1, 1));
w2 = tf.variable(tf.randomUniform([this.units2, this.units3], -1, 1));
b2 = tf.variable(tf.randomUniform([this.units3], -1, 1));
model(x) {
return x.matMul(this.w1).add(this.b1).tanh().matMul(this.w2).add(this.b2);
}
loss(x, y) {
return tf.losses.meanSquaredError(this.model(x), y);
}
train(x, y) {
const optimizer = tf.train.adam(0.001, 0.9, 0.999, 0.00000001);
for (let epoch = 1; epoch <= this.epochs; epoch++) {
optimizer.minimize(() => {
const loss = this.loss(tf.tensor(x), tf.tensor(y));
if (epoch % 1000 == 0) console.log(`${epoch}epoch: loss = ${loss.arraySync()}`);
return loss;
});
}
}
predict(x) {
return this.model(tf.tensor([x])).arraySync()[0];
}
}
const neuralNetwork = new NeuralNetwork();
/* 掛け算の九九の表 */
const x_train = Array.from((function* () { for (let x1 = 1; x1 <= 9; x1++) for (let x2 = 1; x2 <= 9; x2++) yield [x1, x2]; })());
const y_train = Array.from((function* () { for (let x1 = 1; x1 <= 9; x1++) for (let x2 = 1; x2 <= 9; x2++) yield [x1 * x2]; })());
/* 学習 */
neuralNetwork.train(x_train, y_train);
/* 推論 */
for (let x1 = 1; x1 <= 9; x1++) for (let x2 = 1; x2 <= 9; x2++) console.log(`${x1} * ${x2} -> ${neuralNetwork.predict([x1, x2])}`);
/* パラメーターのファイル出力 */
const parameters = `
w1 = ${JSON.stringify(neuralNetwork.w1.arraySync())}${os.EOL}
b1 = ${JSON.stringify(neuralNetwork.b1.arraySync())}${os.EOL}
w2 = ${JSON.stringify(neuralNetwork.w2.arraySync())}${os.EOL}
b2 = ${JSON.stringify(neuralNetwork.b2.arraySync())}${os.EOL}
`;
fs.writeFileSync("parameters.dat", parameters);
注)importステートメントを使ってモジュールを呼び出しているので、ファイルの拡張子は mjs にしてください。
注)3層ニューラルネットワークのモデル関数及び w1、b1、w2、b2 については、ニューラルネットワークで使う数式のまとめ を参照してください。
注)train(x, y)及びpredict(x)の引数は、テンソルではなくJavaScriptの配列です。関数本体の方でtf.tensor(・・・)を使ってテンソルに変換しています。
結論
TensorFlow.js の Core API はとても使い易いです!ニューラルネットワーク以外の回帰分析に使えるのもメリットの一つです。 TensorFlow.js の登場で、これまで Python の独壇場だった AI の分野で JavaScript が使われるようになりそうですね。
参考記事
- TensorFlow
- TensorFlow.js
- TensorFlow.js - インストール
- TensorFlow.js - Node.jsで利用
- TensorFlow.js - トレーニングモデル
- TensorFlow.js API
- TensorFlow.js API - Data
- TensorFlow.js API - Optimizers
- npm - @tensorflow/tfjs