はじめに

機械学習の勉強をしていて、一番困ったのはtensorflowがGPUを使ってくれない問題でした。 半年以上模索して、Windows上で動かせるようにしたのでちょっと紹介。もし、環境作りしたい人がいればまた記事にします(かなり面倒だったので今回は割愛)。機械学習での”Hello world”的なmnist画像の手書き文字分類をしてみます。

今回の環境

os:Windows10

GPU:GeForce 1080ti

R:4.2.2

python:3.9.13(仮想環境)

tensorflow:2.10.1

パッケージ類の読み込み

まず、データ成型のためのtidyverseや描画用のggplotを呼び出しときます。データの取得や加工、結果の描画やレポートの作成まで機械学習の流れをすべてRでできるのが便利すぎるぜ。 それから、Rで実装されている機械学習用のパッケージもデータの加工用に便利だから読み込みます。

library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
## ✔ ggplot2 3.4.0     ✔ purrr   1.0.1
## ✔ tibble  3.2.1     ✔ dplyr   1.1.1
## ✔ tidyr   1.3.0     ✔ stringr 1.5.0
## ✔ readr   2.1.3     ✔ forcats 0.5.2
## Warning: パッケージ 'tibble' はバージョン 4.2.3 の R の下で造られました
## Warning: パッケージ 'tidyr' はバージョン 4.2.3 の R の下で造られました
## Warning: パッケージ 'purrr' はバージョン 4.2.3 の R の下で造られました
## Warning: パッケージ 'dplyr' はバージョン 4.2.3 の R の下で造られました
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(ggplot2)
library(magrittr)
## 
##  次のパッケージを付け加えます: 'magrittr' 
## 
##  以下のオブジェクトは 'package:purrr' からマスクされています:
## 
##     set_names
## 
##  以下のオブジェクトは 'package:tidyr' からマスクされています:
## 
##     extract
library(mlbench)

それからR側からtensorflowを動かすためにtensorflowRパッケージとkerasパッケージを呼び出します。

library(tensorflow)
library(keras)

これで準備は完了。

手書き画像分類

この下の画像みたいな0から9のどれかが書かれた手書き文字をたくさん集めたデータがmnistです。これを機械になんて書いてあるか当てさせようというのが機械学習の初歩の初歩にやることみたいになっています。

まずはデータの読み込みから。kerasパッケージにmnistデータが含まれているので、ダウンロードするだけです。

# mnistのダウンロード 
mnist <- dataset_mnist()

そしたら訓練データとテストデータに振り分ける。機械学習では学習用のデータとそれがどれくらい予測できるかテストするデータに分けます。でも、このダウンロードしたデータは最初から分けられているので名前を付けるだけでいいよん。

# 振り分ける
train_images <- mnist$train$x
train_labels <- mnist$train$y
test_images <- mnist$test$x
test_labels <- mnist$test$y

このあとデータを学習できる形に加工するんだけど長いので割愛(これもRでできるよ!)。

さあネットワークを作ります!わくわく!

# ネットワークのアーティテクチュア
network <- keras_model_sequential() %>% 
  layer_dense(units = 512, activation = "relu", input_shape = c(28*28)) %>% 
  layer_dense(units = 10, activation = "softmax")

なんとたった三行!!画像データは512個のインプットを持ち出力は0~9の10個と最もシンプルな形のネットワークです。

これをコンパイルして…

# コンパイル
network %>% compile(
  optimizer = "rmsprop",
  loss = "categorical_crossentropy",
  metrics = c("accuracy")
)

学習開始!たった二行!

model <- network %>% fit(train_images,train_labels,epochs = 5, batch_size = 128)

学習の過程が見れますね!さあテストしてみましょう。

metrics <- network %>% evaluate(test_images,test_labels)
metrics
##       loss   accuracy 
## 0.07441706 0.97780001

98%近い正答率を出しました!まあチュートリアルなんで「そうかー」程度かもしれないですが自分でできたときは結構感動します!

しかも実装までのコードがめっちゃ少ないです!

次の記事では実践的な例として犬猫判別課題をするのでそちらもご覧ください!

おまけ(詳しい環境)

sessionInfo()
## R version 4.2.2 (2022-10-31 ucrt)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 19045)
## 
## Matrix products: default
## 
## locale:
## [1] LC_COLLATE=Japanese_Japan.utf8  LC_CTYPE=Japanese_Japan.utf8   
## [3] LC_MONETARY=Japanese_Japan.utf8 LC_NUMERIC=C                   
## [5] LC_TIME=Japanese_Japan.utf8    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] keras_2.11.0.9000      tensorflow_2.11.0.9000 mlbench_2.1-3         
##  [4] magrittr_2.0.3         forcats_0.5.2          stringr_1.5.0         
##  [7] dplyr_1.1.1            purrr_1.0.1            readr_2.1.3           
## [10] tidyr_1.3.0            tibble_3.2.1           ggplot2_3.4.0         
## [13] tidyverse_1.3.2       
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_1.0.10          here_1.0.1           lubridate_1.9.0     
##  [4] lattice_0.20-45      png_0.1-8            rprojroot_2.0.3     
##  [7] zeallot_0.1.0        assertthat_0.2.1     digest_0.6.31       
## [10] utf8_1.2.3           R6_2.5.1             cellranger_1.1.0    
## [13] backports_1.4.1      reprex_2.0.2         evaluate_0.20       
## [16] highr_0.10           httr_1.4.5           pillar_1.9.0        
## [19] tfruns_1.5.1         rlang_1.1.0          googlesheets4_1.0.1 
## [22] readxl_1.4.1         rstudioapi_0.14      whisker_0.4.1       
## [25] jquerylib_0.1.4      Matrix_1.5-1         reticulate_1.28-9000
## [28] rmarkdown_2.21       googledrive_2.0.0    munsell_0.5.0       
## [31] broom_1.0.2          compiler_4.2.2       modelr_0.1.10       
## [34] xfun_0.38            base64enc_0.1-3      pkgconfig_2.0.3     
## [37] htmltools_0.5.5      tidyselect_1.2.0     fansi_1.0.4         
## [40] crayon_1.5.2         tzdb_0.3.0           dbplyr_2.2.1        
## [43] withr_2.5.0          grid_4.2.2           jsonlite_1.8.4      
## [46] gtable_0.3.1         lifecycle_1.0.3      DBI_1.1.3           
## [49] scales_1.2.1         cli_3.6.1            stringi_1.7.12      
## [52] cachem_1.0.7         fs_1.6.1             xml2_1.3.3          
## [55] bslib_0.4.2          ellipsis_0.3.2       generics_0.1.3      
## [58] vctrs_0.6.1          tools_4.2.2          glue_1.6.2          
## [61] hms_1.1.2            fastmap_1.1.1        yaml_2.3.7          
## [64] timechange_0.1.1     colorspace_2.1-0     gargle_1.2.1        
## [67] rvest_1.0.3          knitr_1.42           haven_2.5.1         
## [70] sass_0.4.5