我正在尝试使用Scikit-Learn训练KMeans模型。我已经为这个问题困扰了两天。即使我指定了两个列,Pandas仍然选择数据框的所有列。这里是数据框以.to_dict()形式展示的示例:
{'area': {0: 15.26, 1: 14.88, 2: 14.29, 3: 13.84, 4: 16.14, 5: 14.38, 6: 14.69, 7: 16.63, 8: 16.44, 9: 15.26, 10: 14.03, 11: 13.89, 12: 13.78, 13: 13.74, 14: 14.59, 15: 13.99, 16: 15.69, 17: 14.7, 18: 12.72, 19: 14.16, 20: 14.11, 21: 15.88, 22: 12.08, 23: 15.01, 24: 16.19, 25: 13.02, 26: 12.74, 27: 14.11, 28: 13.45, 29: 13.16, 30: 15.49, 31: 14.09, 32: 13.94, 33: 15.05, 34: 16.2, 35: 17.08, 36: 14.8, 37: 14.28, 38: 13.54, 39: 13.5, 40: 13.16, 41: 15.5, 42: 15.11, 43: 13.8, 44: 15.36, 45: 14.99, 46: 14.79, 47: 14.86, 48: 14.43, 49: 15.78, 50: 14.49, 51: 14.33, 52: 14.52, 53: 15.03, 54: 14.46, 55: 14.92, 56: 15.38, 57: 12.11, 58: 11.23, 59: 12.36, 60: 13.22, 61: 12.78, 62: 12.88, 63: 14.34, 64: 14.01, 65: 12.73, 66: 17.63, 67: 16.84, 68: 17.26, 69: 19.11, 70: 16.82, 71: 16.77, 72: 17.32, 73: 20.71, 74: 18.94, 75: 17.12, 76: 16.53, 77: 18.72, 78: 20.2, 79: 19.57, 80: 19.51, 81: 18.27, 82: 18.88, 83: 18.98, 84: 21.18, 85: 20.88, 86: 20.1, 87: 18.76, 88: 18.81, 89: 18.59, 90: 18.36, 91: 16.87, 92: 19.31, 93: 18.98, 94: 18.17, 95: 18.72, 96: 16.41, 97: 17.99, 98: 19.46, 99: 19.18, 100: 18.95, 101: 18.83, 102: 17.63, 103: 19.94, 104: 18.55, 105: 18.45, 106: 19.38, 107: 19.13, 108: 19.14, 109: 20.97, 110: 19.06, 111: 18.96, 112: 19.15, 113: 18.89, 114: 20.03, 115: 20.24, 116: 18.14, 117: 16.17, 118: 18.43, 119: 15.99, 120: 18.75, 121: 18.65, 122: 17.98, 123: 20.16, 124: 17.55, 125: 18.3, 126: 18.94, 127: 15.38, 128: 16.16, 129: 15.56, 130: 17.36, 131: 15.57, 132: 15.6, 133: 16.23, 134: 13.07, 135: 13.32, 136: 13.34, 137: 12.22, 138: 11.82, 139: 11.21, 140: 11.43, 141: 12.49, 142: 12.7, 143: 10.79, 144: 11.83, 145: 12.01, 146: 12.26, 147: 11.18, 148: 11.36, 149: 11.19, 150: 11.34, 151: 12.13, 152: 11.75, 153: 11.49, 154: 12.54, 155: 12.02, 156: 12.05, 157: 12.55, 158: 11.14, 159: 12.1, 160: 12.44, 161: 12.15, 162: 11.35, 163: 11.55, 164: 11.4, 165: 10.83, 166: 10.8, 167: 11.26, 168: 10.74, 169: 11.48, 170: 12.21, 171: 11.41, 172: 12.46, 173: 12.19, 174: 11.65, 175: 12.89, 
176: 11.56, 177: 11.81, 178: 10.91, 179: 11.23, 180: 10.59, 181: 10.93, 182: 11.27, 183: 11.87, 184: 10.82, 185: 12.11, 186: 12.8, 187: 12.79, 188: 13.37, 189: 12.62, 190: 12.76, 191: 12.38, 192: 11.18, 193: 12.37, 194: 12.19, 195: 11.23, 196: 13.2, 197: 11.84, 198: 12.3}, 'perimeter': {0: 14.84, 1: 14.57, 2: 14.09, 3: 13.94, 4: 14.99, 5: 14.21, 6: 14.49, 7: 15.46, 8: 15.25, 9: 14.85, 10: 14.16, 11: 14.02, 12: 14.06, 13: 14.05, 14: 14.28, 15: 13.83, 16: 14.75, 17: 14.21, 18: 13.57, 19: 14.4, 20: 14.26, 21: 14.9, 22: 13.23, 23: 14.76, 24: 15.16, 25: 13.76, 26: 13.67, 27: 14.18, 28: 14.02, 29: 13.82, 30: 14.94, 31: 14.41, 32: 14.17, 33: 14.68, 34: 15.27, 35: 15.38, 36: 14.52, 37: 14.17, 38: 13.85, 39: 13.85, 40: 13.55, 41: 14.86, 42: 14.54, 43: 14.04, 44: 14.76, 45: 14.56, 46: 14.52, 47: 14.67, 48: 14.4, 49: 14.91, 50: 14.61, 51: 14.28, 52: 14.6, 53: 14.77, 54: 14.35, 55: 14.43, 56: 14.77, 57: 13.47, 58: 12.63, 59: 13.19, 60: 13.84, 61: 13.57, 62: 13.5, 63: 14.37, 64: 14.29, 65: 13.75, 66: 15.98, 67: 15.67, 68: 15.73, 69: 16.26, 70: 15.51, 71: 15.62, 72: 15.91, 73: 17.23, 74: 16.49, 75: 15.55, 76: 15.34, 77: 16.19, 78: 16.89, 79: 16.74, 80: 16.71, 81: 16.09, 82: 16.26, 83: 16.66, 84: 17.21, 85: 17.05, 86: 16.99, 87: 16.2, 88: 16.29, 89: 16.05, 90: 16.52, 91: 15.65, 92: 16.59, 93: 16.57, 94: 16.26, 95: 16.34, 96: 15.25, 97: 15.86, 98: 16.5, 99: 16.63, 100: 16.42, 101: 16.29, 102: 15.86, 103: 16.92, 104: 16.22, 105: 16.12, 106: 16.72, 107: 16.31, 108: 16.61, 109: 17.25, 110: 16.45, 111: 16.2, 112: 16.45, 113: 16.23, 114: 16.9, 115: 16.91, 116: 16.12, 117: 15.38, 118: 15.97, 119: 14.89, 120: 16.18, 121: 16.41, 122: 15.85, 123: 17.03, 124: 15.66, 125: 15.89, 126: 16.32, 127: 14.9, 128: 15.33, 129: 14.89, 130: 15.76, 131: 15.15, 132: 15.11, 133: 15.18, 134: 13.92, 135: 13.94, 136: 13.95, 137: 13.32, 138: 13.4, 139: 13.13, 140: 13.13, 141: 13.46, 142: 13.71, 143: 12.93, 144: 13.23, 145: 13.52, 146: 13.6, 147: 13.04, 148: 13.05, 149: 13.05, 150: 12.87, 151: 13.73, 152: 
13.52, 153: 13.22, 154: 13.67, 155: 13.33, 156: 13.41, 157: 13.57, 158: 12.79, 159: 13.15, 160: 13.59, 161: 13.45, 162: 13.12, 163: 13.1, 164: 13.08, 165: 12.96, 166: 12.57, 167: 13.01, 168: 12.73, 169: 13.05, 170: 13.47, 171: 12.95, 172: 13.41, 173: 13.36, 174: 13.07, 175: 13.77, 176: 13.31, 177: 13.45, 178: 12.8, 179: 12.82, 180: 12.41, 181: 12.8, 182: 12.86, 183: 13.02, 184: 12.83, 185: 13.27, 186: 13.47, 187: 13.53, 188: 13.78, 189: 13.67, 190: 13.38, 191: 13.44, 192: 12.72, 193: 13.47, 194: 13.2, 195: 12.88, 196: 13.66, 197: 13.21, 198: 13.34}, 'compactness': {0: 0.871, 1: 0.8811, 2: 0.905, 3: 0.8955, 4: 0.9034, 5: 0.8951, 6: 0.8799, 7: 0.8747, 8: 0.888, 9: 0.8696, 10: 0.8796, 11: 0.888, 12: 0.8759, 13: 0.8744, 14: 0.8993, 15: 0.9183, 16: 0.9058, 17: 0.9153, 18: 0.8686, 19: 0.8584, 20: 0.8722, 21: 0.8988, 22: 0.8664, 23: 0.8657, 24: 0.8849, 25: 0.8641, 26: 0.8564, 27: 0.882, 28: 0.8604, 29: 0.8662, 30: 0.8724, 31: 0.8529, 32: 0.8728, 33: 0.8779, 34: 0.8734, 35: 0.9079, 36: 0.8823, 37: 0.8944, 38: 0.8871, 39: 0.8852, 40: 0.9009, 41: 0.882, 42: 0.8986, 43: 0.8794, 44: 0.8861, 45: 0.8883, 46: 0.8819, 47: 0.8676, 48: 0.8751, 49: 0.8923, 50: 0.8538, 51: 0.8831, 52: 0.8557, 53: 0.8658, 54: 0.8818, 55: 0.9006, 56: 0.8857, 57: 0.8392, 58: 0.884, 59: 0.8923, 60: 0.868, 61: 0.8716, 62: 0.8879, 63: 0.8726, 64: 0.8625, 65: 0.8458, 66: 0.8673, 67: 0.8623, 68: 0.8763, 69: 0.9081, 70: 0.8786, 71: 0.8638, 72: 0.8599, 73: 0.8763, 74: 0.875, 75: 0.8892, 76: 0.8823, 77: 0.8977, 78: 0.8894, 79: 0.8779, 80: 0.878, 81: 0.887, 82: 0.8969, 83: 0.859, 84: 0.8989, 85: 0.9031, 86: 0.8746, 87: 0.8984, 88: 0.8906, 89: 0.9066, 90: 0.8452, 91: 0.8648, 92: 0.8815, 93: 0.8687, 94: 0.8637, 95: 0.881, 96: 0.8866, 97: 0.8992, 98: 0.8985, 99: 0.8717, 100: 0.8829, 101: 0.8917, 102: 0.88, 103: 0.8752, 104: 0.8865, 105: 0.8921, 106: 0.8716, 107: 0.9035, 108: 0.8722, 109: 0.8859, 110: 0.8854, 111: 0.9077, 112: 0.889, 113: 0.9008, 114: 0.8811, 115: 0.8897, 116: 0.8772, 117: 0.8588, 118: 0.9077, 119: 
0.9064, 120: 0.8999, 121: 0.8698, 122: 0.8993, 123: 0.8735, 124: 0.8991, 125: 0.9108, 126: 0.8942, 127: 0.8706, 128: 0.8644, 129: 0.8823, 130: 0.8785, 131: 0.8527, 132: 0.858, 133: 0.885, 134: 0.848, 135: 0.8613, 136: 0.862, 137: 0.8652, 138: 0.8274, 139: 0.8167, 140: 0.8335, 141: 0.8658, 142: 0.8491, 143: 0.8107, 144: 0.8496, 145: 0.8249, 146: 0.8333, 147: 0.8266, 148: 0.8382, 149: 0.8253, 150: 0.8596, 151: 0.8081, 152: 0.8082, 153: 0.8263, 154: 0.8425, 155: 0.8503, 156: 0.8416, 157: 0.8558, 158: 0.8558, 159: 0.8793, 160: 0.8462, 161: 0.8443, 162: 0.8291, 163: 0.8455, 164: 0.8375, 165: 0.8099, 166: 0.859, 167: 0.8355, 168: 0.8329, 169: 0.8473, 170: 0.8453, 171: 0.856, 172: 0.8706, 173: 0.8579, 174: 0.8575, 175: 0.8541, 176: 0.8198, 177: 0.8198, 178: 0.8372, 179: 0.8594, 180: 0.8648, 181: 0.839, 182: 0.8563, 183: 0.8795, 184: 0.8256, 185: 0.8639, 186: 0.886, 187: 0.8786, 188: 0.8849, 189: 0.8481, 190: 0.8964, 191: 0.8609, 192: 0.868, 193: 0.8567, 194: 0.8783, 195: 0.8511, 196: 0.8883, 197: 0.8521, 198: 0.8684}, 'length': {0: 5.763, 1: 5.554, 2: 5.291, 3: 5.324, 4: 5.658, 5: 5.386, 6: 5.563, 7: 6.053, 8: 5.884, 9: 5.714, 10: 5.438, 11: 5.439, 12: 5.479, 13: 5.482, 14: 5.351, 15: 5.119, 16: 5.527, 17: 5.205, 18: 5.226, 19: 5.658, 20: 5.52, 21: 5.618, 22: 5.099, 23: 5.789, 24: 5.833, 25: 5.395, 26: 5.395, 27: 5.541, 28: 5.516, 29: 5.454, 30: 5.757, 31: 5.717, 32: 5.585, 33: 5.712, 34: 5.826, 35: 5.832, 36: 5.656, 37: 5.397, 38: 5.348, 39: 5.351, 40: 5.138, 41: 5.877, 42: 5.579, 43: 5.376, 44: 5.701, 45: 5.57, 46: 5.545, 47: 5.678, 48: 5.585, 49: 5.674, 50: 5.715, 51: 5.504, 52: 5.741, 53: 5.702, 54: 5.388, 55: 5.384, 56: 5.662, 57: 5.159, 58: 4.902, 59: 5.076, 60: 5.395, 61: 5.262, 62: 5.139, 63: 5.63, 64: 5.609, 65: 5.412, 66: 6.191, 67: 5.998, 68: 5.978, 69: 6.154, 70: 6.017, 71: 5.927, 72: 6.064, 73: 6.579, 74: 6.445, 75: 5.85, 76: 5.875, 77: 6.006, 78: 6.285, 79: 6.384, 80: 6.366, 81: 6.173, 82: 6.084, 83: 6.549, 84: 6.573, 85: 6.45, 86: 6.581, 87: 6.172, 88: 
6.272, 89: 6.037, 90: 6.666, 91: 6.139, 92: 6.341, 93: 6.449, 94: 6.271, 95: 6.219, 96: 5.718, 97: 5.89, 98: 6.113, 99: 6.369, 100: 6.248, 101: 6.037, 102: 6.033, 103: 6.675, 104: 6.153, 105: 6.107, 106: 6.303, 107: 6.183, 108: 6.259, 109: 6.563, 110: 6.416, 111: 6.051, 112: 6.245, 113: 6.227, 114: 6.493, 115: 6.315, 116: 6.059, 117: 5.762, 118: 5.98, 119: 5.363, 120: 6.111, 121: 6.285, 122: 5.979, 123: 6.513, 124: 5.791, 125: 5.979, 126: 6.144, 127: 5.884, 128: 5.845, 129: 5.776, 130: 6.145, 131: 5.92, 132: 5.832, 133: 5.872, 134: 5.472, 135: 5.541, 136: 5.389, 137: 5.224, 138: 5.314, 139: 5.279, 140: 5.176, 141: 5.267, 142: 5.386, 143: 5.317, 144: 5.263, 145: 5.405, 146: 5.408, 147: 5.22, 148: 5.175, 149: 5.25, 150: 5.053, 151: 5.394, 152: 5.444, 153: 5.304, 154: 5.451, 155: 5.35, 156: 5.267, 157: 5.333, 158: 5.011, 159: 5.105, 160: 5.319, 161: 5.417, 162: 5.176, 163: 5.167, 164: 5.136, 165: 5.278, 166: 4.981, 167: 5.186, 168: 5.145, 169: 5.18, 170: 5.357, 171: 5.09, 172: 5.236, 173: 5.24, 174: 5.108, 175: 5.495, 176: 5.363, 177: 5.413, 178: 5.088, 179: 5.089, 180: 4.899, 181: 5.046, 182: 5.091, 183: 5.132, 184: 5.18, 185: 5.236, 186: 5.16, 187: 5.224, 188: 5.32, 189: 5.41, 190: 5.073, 191: 5.219, 192: 5.009, 193: 5.204, 194: 5.137, 195: 5.14, 196: 5.236, 197: 5.175, 198: 5.243}}
这是我的代码:
# All column names of the seeds dataset (kept for reference; only x and y are used below).
cols = ['area', 'perimeter', 'compactness', 'length', 'width', 'asymmetry', 'groove', 'class']
x = 'perimeter'
y = 'asymmetry'

# Select ONLY the two chosen feature columns -> ndarray of shape (n_samples, 2).
z = df[[x, y]].values

# n_init=10: run k-means with 10 different centroid seeds and keep the best result.
kmeans = KMeans(n_clusters=3, n_init=10).fit(z)
clusters = kmeans.labels_
print(clusters)

# FIX: the original code used np.hstack((z, clusters.reshape(-1, 1))), which
# coerces the integer cluster labels into floats inside one homogeneous float
# array (0.0 / 1.0 / 2.0). seaborn then treats hue='class' as a numeric
# (continuous) variable, so the predicted-cluster plot is colored wrongly.
# Copying the two feature columns and assigning the label column directly
# keeps the labels as integers, so hue is treated as discrete categories.
cluster_df = df[[x, y]].copy()
cluster_df["class"] = clusters

# Plot the k-means clustering result (discrete hue per cluster label).
sns.scatterplot(x=x, y=y, hue='class', data=cluster_df)
plt.show()

# Plot the original data colored by its ground-truth 'class' column for comparison.
sns.scatterplot(x=x, y=y, hue='class', data=df)
plt.show()
我期望预测后的图表看起来像原始数据集的图表。
预测后的图表看起来像是sns绘制了所有不同的列在一起。
附注:我是一个初学者,所以可能有些东西我不懂,请不要讨厌我。
编辑:我认为问题出在变量 cluster_df 上,但我不知道它哪里有问题,也不知道应该怎么改。
回答:
原来是这行代码的问题:
kmeans = KMeans(n_clusters=3, n_init=10).fit(z)
应该改成:
kmeans = KMeans(n_clusters=3).fit(z)
去掉n_init=10