通过颜色识别点

我正在按照这个教程学习:https://www.rpubs.com/loveb/som。这个教程展示了如何在鸢尾花数据集上使用Kohonen网络(也称为SOM,一种机器学习算法)。

我运行了教程中的以下代码:

library(kohonen) #fitting SOMslibrary(ggplot2) #plotslibrary(GGally) #plotslibrary(RColorBrewer) #colors, using predefined palettesiris_complete <-iris[complete.cases(iris),] iris_unique <- unique(iris_complete) # Remove duplicates#scale datairis.sc = scale(iris_unique[, 1:4]) #Levels/Factors cannot be scaled... But used in predictive SOM:s using xyf. Later.#build gridiris.grid = somgrid(xdim = 10, ydim=10, topo="hexagonal", toroidal = TRUE)set.seed(33) #for reproducabilityiris.som <- som(iris.sc, grid=iris.grid, rlen=700, alpha=c(0.05,0.01), keep.data = TRUE)#plot 1plot(iris.som, type="count")#plot2var <- 1 #define the variable to plotplot(iris.som, type = "property", property = getCodes(iris.som)[,var], main=colnames(getCodes(iris.som))[var], palette.name=terrain.colors)

上述代码在鸢尾花数据集上拟合了一个Kohonen网络。数据集中的每个观测值都被分配到下图中的“彩色圆圈”(也称为“神经元”)之一。

我的问题是:在这些图表中,如何识别哪些观测值被分配到哪些圆圈?假设我想知道哪些观测值属于下方用黑色三角形标记的圆圈:

enter image description hereenter image description here

这是可能的吗?目前,我正试图使用iris.som$classif来追踪哪些点在哪个圆圈里。有没有更好的方法来做这件事?

更新:@Jonny Phelps向我展示了如何识别三角形内的观测值(见下方答案)。但我仍然不确定是否可以识别不规则形状的形式。例如:enter image description here

在之前的一个帖子中(在图表上标记点(R语言)),一位用户向我展示了如何为网格上的每个圆圈分配任意编号:

enter image description here

基于上面的图表,如何使用“som$classif”语句来找出哪些观测值在圆圈92、91、82、81、72和71中?

谢谢


回答:

编辑:现在有了Shiny应用!

也可以使用plotly解决方案,您可以将鼠标悬停在各个神经元上以显示相关的鸢尾花行名称(此处称为id)。基于您的iris.som数据和Jonny Phelps的网格方法,您可以将行号作为连接的字符串分配给各个神经元,并在鼠标悬停时显示这些内容:

library(ggplot2)library(plotly)ga <- data.frame(g=iris.som$unit.classif,                  sample=seq_len(dim(iris.som$data[[1]])[1]))grid_pts <- as.data.frame(iris.som$grid$pts)grid_pts$column <- rep(1:iris.som$grid$xdim, by=iris.som$grid$ydim)grid_pts$row <- rep(1:iris.som$grid$ydim, each=iris.som$grid$xdim)grid_pts$classif <- 1:nrow(grid_pts)grid_pts$id <- sapply(seq_along(grid_pts$classif),                       function(x) paste(ga$sample[ga$g==x], collapse=", "))grid_pts$count <- sapply(seq_along(grid_pts$classif),                          function(x) length(ga$sample[ga$g==x]))grid_pts$count <- factor(grid_pts$count, levels=0:max(grid_pts$count))p1 <- ggplot(grid_pts, aes(x=x, y=y, colour=count, row=row, column=column, id=id)) +    geom_point(size=8) +    scale_colour_manual(values=c("grey50", heat.colors(length(unique(grid_pts$count))))) +    theme_void() +    theme(plot.margin=unit(c(1,rep(.3, 3)),"cm"))ggplotly(p1)

这里是一个完整的Shiny应用,允许套索选择并显示包含数据的表格:

invisible(suppressPackageStartupMessages(    lapply(c("shiny","dplyr","ggplot2", "plotly", "kohonen", "GGally", "DT"),           require, character.only=TRUE)))iris_complete <- iris[complete.cases(iris),] iris_unique <- unique(iris_complete) # Remove duplicates#scale datairis.sc = scale(iris_unique[, 1:4]) #Levels/Factors cannot be scaled... But used in predictive SOM:s using xyf. Later.#build gridiris.grid = somgrid(xdim = 10, ydim=10, topo="hexagonal", toroidal = TRUE)set.seed(33) #for reproducabilityiris.som <- som(iris.sc, grid=iris.grid, rlen=700, alpha=c(0.05,0.01), keep.data = TRUE)ga <- data.frame(g=iris.som$unit.classif,                  sample=seq_len(dim(iris.som$data[[1]])[1]))grid_pts <- as.data.frame(iris.som$grid$pts)grid_pts$column <- rep(1:iris.som$grid$xdim, by=iris.som$grid$ydim)grid_pts$row <- rep(1:iris.som$grid$ydim, each=iris.som$grid$xdim)grid_pts$classif <- 1:nrow(grid_pts)grid_pts$id <- sapply(seq_along(grid_pts$classif),                       function(x) paste(ga$sample[ga$g==x], collapse=", "))grid_pts$count <- sapply(seq_along(grid_pts$classif),                          function(x) length(ga$sample[ga$g==x]))grid_pts$count <- factor(grid_pts$count, levels=0:max(grid_pts$count))# Shiny app, adapted from https://gist.github.com/dgrapov/128e3be71965bf00495768e47f0428b9ui <- fluidPage(    fluidRow(        column(12, plotlyOutput("plot", height = "600px")),        column(12, DT::dataTableOutput('data_table'))    ))server <- function(input, output){    output$plot <- renderPlotly({        plot_ly(data = grid_pts, x = ~x, y = ~y, color = ~count, text = ~id, hoverinfo = "text", type = "scatter", mode = "markers", marker = list(size = 8)) %>%            layout(xaxis = list(showgrid = FALSE, showticklabels = FALSE),                   yaxis = list(showgrid = FALSE, showticklabels = FALSE),                   plot_bgcolor = "transparent",                   paper_bgcolor = "transparent")    })    output$data_table <- DT::renderDataTable({        datatable(iris_unique, filter = "top", options = list(pageLength = 10))    })}shinyApp(ui = ui, server = server)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注