# Utility: print each cluster label followed by the company names assigned to it.
def print_clusters(df_combined, cluster_labels):
    """Print the companies that belong to each cluster.

    Args:
        df_combined: object whose ``columns`` hold the company names; the
            i-th column is paired with the i-th entry of ``cluster_labels``.
        cluster_labels: iterable of cluster labels, one per column.

    One line is printed per distinct label, in order of first appearance,
    e.g. ``Cluster 0: A, C``.
    """
    members_by_label = {}
    for position, label in enumerate(cluster_labels):
        members_by_label.setdefault(label, []).append(df_combined.columns[position])

    # Emit the grouped companies, one cluster per line.
    for label, members in members_by_label.items():
        print(f"Cluster {label}: {', '.join(members)}")
def plot_cluster_heatmaps(cluster_results, companies):
    """Plot one-hot cluster-membership heatmaps for several clustering methods.

    Each method gets its own subplot in a two-column grid. In a method's
    heatmap, row ``r`` corresponds to the ``r``-th distinct label (sorted
    ascending) and column ``j`` to ``companies[j]``; a cell is 1 when that
    company was assigned that label and 0 otherwise. The y tick labels show
    the actual label values, so noise labels such as -1 remain identifiable.

    Args:
        cluster_results: mapping of method name -> sequence of cluster labels,
            one label per company, in the same order as ``companies``.
        companies: list of company names, used as x-axis tick labels.

    Raises:
        ValueError: if a method's label sequence length differs from
            ``len(companies)``.
    """
    methods = list(cluster_results.keys())
    labels = list(cluster_results.values())
    n_companies = len(companies)
    columns = np.arange(n_companies)

    # Build one (n_distinct_labels x n_companies) one-hot matrix per method.
    heatmaps = []
    row_tick_labels = []
    for method, method_labels in zip(methods, labels):
        method_labels = np.asarray(method_labels)
        if len(method_labels) != n_companies:
            raise ValueError(
                f"{method!r}: got {len(method_labels)} labels for {n_companies} companies"
            )
        # Map each label to a dense row index instead of using the raw label
        # as an index: raw indexing misplaces -1 (noise) into the last row and
        # overflows when labels have gaps such as {0, 2}.
        unique_labels, row_index = np.unique(method_labels, return_inverse=True)
        heatmap = np.zeros((len(unique_labels), n_companies))
        heatmap[row_index.ravel(), columns] = 1
        heatmaps.append(heatmap)
        row_tick_labels.append([str(value) for value in unique_labels])

    # Two columns; keep the original 2x2 (12x12) layout for up to four
    # methods and add rows beyond that instead of failing with IndexError.
    nrows = max(2, (len(methods) + 1) // 2)
    fig, axs = plt.subplots(nrows=nrows, ncols=2, figsize=(12, 6 * nrows), squeeze=False)

    for i, (method, heatmap, yticks) in enumerate(zip(methods, heatmaps, row_tick_labels)):
        ax = axs[i // 2, i % 2]
        sns.heatmap(heatmap, cmap="Blues", annot=True, fmt="g",
                    xticklabels=companies, yticklabels=yticks, ax=ax)
        ax.set_title(method)

    # Blank out grid cells that received no heatmap.
    for ax in axs.flat[len(methods):]:
        ax.axis("off")