我正在按照这个教程学习:https://www.rpubs.com/loveb/som。这个教程展示了如何在鸢尾花数据集上使用Kohonen网络(也称为SOM,一种机器学习算法)。
我运行了教程中的以下代码:
# Tutorial code: fit a 10x10 toroidal hexagonal SOM on the scaled numeric
# iris columns and draw two diagnostic plots.
# Each statement is on its own line: when the script was collapsed onto one
# line, the first `#` commented out everything after library(kohonen).
library(kohonen)      # fitting SOMs
library(ggplot2)      # plots
library(GGally)       # plots
library(RColorBrewer) # colors, using predefined palettes

iris_complete <- iris[complete.cases(iris), ]
iris_unique <- unique(iris_complete) # Remove duplicates

# Scale data. Levels/factors cannot be scaled, but they can be used in
# predictive SOMs via xyf() (covered later in the tutorial).
iris.sc <- scale(iris_unique[, 1:4])

# Build grid
iris.grid <- somgrid(xdim = 10, ydim = 10, topo = "hexagonal", toroidal = TRUE)
set.seed(33) # for reproducibility
iris.som <- som(
  iris.sc,
  grid = iris.grid,
  rlen = 700,
  alpha = c(0.05, 0.01),
  keep.data = TRUE
)

# Plot 1: per-neuron counts. "counts" is the full type name; the original
# "count" only worked through partial matching.
plot(iris.som, type = "counts")

# Plot 2: codebook values of one variable across the map.
# Named var_idx rather than `var` to avoid masking stats::var().
var_idx <- 1 # column of the codebook matrix to plot
codes <- getCodes(iris.som)
plot(
  iris.som,
  type = "property",
  property = codes[, var_idx],
  main = colnames(codes)[var_idx],
  palette.name = terrain.colors
)
上述代码在鸢尾花数据集上拟合了一个Kohonen网络。数据集中的每个观测值都被分配到下图中的“彩色圆圈”(也称为“神经元”)之一。
我的问题是:在这些图表中,如何识别哪些观测值被分配到哪些圆圈?假设我想知道哪些观测值属于下方用黑色三角形标记的圆圈:
这可能实现吗?目前,我正尝试使用 `iris.som$classif`(在 kohonen 的 `som` 对象中,对应的元素实际名为 `unit.classif`)来追踪哪些点位于哪个圆圈中。有没有更好的方法?
更新:@Jonny Phelps向我展示了如何识别三角形内的观测值(见下方答案)。但我仍然不确定能否识别由多个圆圈组成的不规则形状区域内的观测值。例如:
在之前的一个帖子中(在图表上标记点(R语言)),一位用户向我展示了如何为网格上的每个圆圈分配任意编号:
基于上面的图表,如何使用 `som$classif`(即 `iris.som$unit.classif`)找出哪些观测值位于圆圈92、91、82、81、72和71中?
谢谢
回答:
编辑:现在有了Shiny应用!
也可以采用 plotly 方案:将鼠标悬停在各个神经元上,即可显示映射到该神经元的鸢尾花行名称(此处称为 id)。基于您的 `iris.som` 数据和 Jonny Phelps 的网格方法,可以把各神经元对应的行号拼接成一个字符串赋给该神经元,并在鼠标悬停时显示:
# Interactive hex-grid plot of the SOM: each point is a neuron, coloured by
# how many observations map to it; hovering shows those observations' ids.
# Assumes `iris.som` and `iris_unique` from the tutorial code are in scope.
library(ggplot2)
library(plotly)

# One row per observation: the winning unit (g) and its position (sample).
n_obs <- nrow(iris.som$data[[1]])
ga <- data.frame(g = iris.som$unit.classif, sample = seq_len(n_obs))

grid_pts <- as.data.frame(iris.som$grid$pts)
n_units <- nrow(grid_pts)
xdim <- iris.som$grid$xdim
ydim <- iris.som$grid$ydim

# Column index cycles fastest, row index slowest. rep() has no `by`
# argument (it was silently ignored before); `times` is what was meant.
grid_pts$column <- rep(seq_len(xdim), times = ydim)
grid_pts$row <- rep(seq_len(ydim), each = xdim)
grid_pts$classif <- seq_len(n_units)

# Label observations by their iris row names rather than their position in
# iris_unique. The duplicate row removed by unique() makes positions drift
# from row names after position 142.
obs_names <- rownames(iris_unique)

# Observation positions per unit. Using the factor levels keeps empty units,
# which get an empty id.
members <- split(ga$sample, factor(ga$g, levels = grid_pts$classif))
grid_pts$id <- vapply(
  members,
  function(s) paste(obs_names[s], collapse = ", "),
  character(1),
  USE.NAMES = FALSE
)
grid_pts$count <- tabulate(ga$g, nbins = n_units)
grid_pts$count <- factor(grid_pts$count, levels = 0:max(grid_pts$count))

# row/column/id are not drawn by ggplot2; ggplotly() shows them in tooltips.
p1 <- ggplot(
  grid_pts,
  aes(x = x, y = y, colour = count, row = row, column = column, id = id)
) +
  geom_point(size = 8) +
  scale_colour_manual(
    values = c("grey50", heat.colors(length(unique(grid_pts$count))))
  ) +
  theme_void() +
  theme(plot.margin = unit(c(1, rep(0.3, 3)), "cm"))

ggplotly(p1)
下面是一个完整的 Shiny 应用,支持在图上进行套索选择,并在下方的表格中显示数据:
# Shiny app: lasso-select SOM neurons in a plotly chart; the table below lists
# the iris observations mapped to the selected neurons (all rows when nothing
# is selected).
# Adapted from https://gist.github.com/dgrapov/128e3be71965bf00495768e47f0428b9

# library() stops with an error when a package is missing. require() only
# returns FALSE and lets the app fail later.
pkgs <- c("shiny", "dplyr", "ggplot2", "plotly", "kohonen", "GGally", "DT")
suppressPackageStartupMessages(
  for (pkg in pkgs) library(pkg, character.only = TRUE)
)

iris_complete <- iris[complete.cases(iris), ]
iris_unique <- unique(iris_complete) # Remove duplicates

# Scale data. Levels/factors cannot be scaled, but they can be used in
# predictive SOMs via xyf().
iris.sc <- scale(iris_unique[, 1:4])

# Build grid
iris.grid <- somgrid(xdim = 10, ydim = 10, topo = "hexagonal", toroidal = TRUE)
set.seed(33) # for reproducibility
iris.som <- som(
  iris.sc,
  grid = iris.grid,
  rlen = 700,
  alpha = c(0.05, 0.01),
  keep.data = TRUE
)

# One row per observation: the winning unit (g) and its position (sample),
# which is also its row position in iris_unique.
n_obs <- nrow(iris.som$data[[1]])
ga <- data.frame(g = iris.som$unit.classif, sample = seq_len(n_obs))

grid_pts <- as.data.frame(iris.som$grid$pts)
n_units <- nrow(grid_pts)
xdim <- iris.som$grid$xdim
ydim <- iris.som$grid$ydim

# rep() has no `by` argument (it was silently ignored); `times` is meant.
grid_pts$column <- rep(seq_len(xdim), times = ydim)
grid_pts$row <- rep(seq_len(ydim), each = xdim)
grid_pts$classif <- seq_len(n_units)

# Hover labels use iris row names so they match the row names the DT table
# shows. Positions drift from row names after the removed duplicate row.
obs_names <- rownames(iris_unique)
members <- split(ga$sample, factor(ga$g, levels = grid_pts$classif))
grid_pts$id <- vapply(
  members,
  function(s) paste(obs_names[s], collapse = ", "),
  character(1),
  USE.NAMES = FALSE
)
grid_pts$count <- tabulate(ga$g, nbins = n_units)
grid_pts$count <- factor(grid_pts$count, levels = 0:max(grid_pts$count))

ui <- fluidPage(
  fluidRow(
    column(12, plotlyOutput("plot", height = "600px")),
    column(12, DT::dataTableOutput("data_table"))
  )
)

server <- function(input, output) {
  output$plot <- renderPlotly({
    plot_ly(
      data = grid_pts,
      x = ~x, y = ~y,
      color = ~count,
      text = ~id, hoverinfo = "text",
      # Colouring by count splits the points into several traces, so
      # pointNumber is trace-relative. `key` carries the unit number instead.
      key = ~classif,
      source = "som",
      type = "scatter", mode = "markers",
      marker = list(size = 8)
    ) %>%
      layout(
        dragmode = "lasso",
        xaxis = list(showgrid = FALSE, showticklabels = FALSE),
        yaxis = list(showgrid = FALSE, showticklabels = FALSE),
        plot_bgcolor = "transparent",
        paper_bgcolor = "transparent"
      )
  })

  output$data_table <- DT::renderDataTable({
    selected <- event_data("plotly_selected", source = "som")
    rows <- if (is.null(selected)) {
      seq_len(nrow(iris_unique))
    } else {
      ga$sample[ga$g %in% as.integer(selected$key)]
    }
    datatable(
      iris_unique[rows, , drop = FALSE],
      filter = "top",
      options = list(pageLength = 10)
    )
  })
}

shinyApp(ui = ui, server = server)