Py学习  »  Python

【Python】必用 !绝美三维散点图 !!

机器学习初学者 • 1 月前 • 177 次点击  

今儿和大家分享的是关于经常用到的三维散点图。

我们会从 5 种不同样式的绘制,包括 Matplotlib 多种风格、Plotly 交互、PyVista/VTK 高级渲染、密度/等值面可视化与多面板比较图等等方面进行阐述~

本文案例中,三维点会导出高分辨率栅格图(PNG/TIFF,600 dpi)或交互图,另外矢量格式(PDF/SVG)对大量点可能不理想。

首先,我们需要先给出一个通用数据生成与分析工具函数,后续示例均以此数据为输入。

# 数据集

大家可以直接使用,代码实现:

# common_data.py
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.neighbors import KernelDensity

def generate_synthetic_data(seed=42, n=1200):
    """
    生成用于论文示例的三维数据:包含若干簇 + 噪声 + 渐变属性。
    返回 DataFrame: cols = ['x','y','z','value']
    """

    rng = np.random.RandomState(seed)
    # 三个簇 + 噪声
    centers = np.array([
        [0.00.00.0],
        [3.52.01.0],
        [-2.03.04.0],
    ])
    rows = []
    for i, c in enumerate(centers):
        pts = c + 0.7 * rng.randn(n//43) + (i * 0.2)  # slight shift
        rows.append(pts)
    # 一些随机分布的点
    rand_pts = 6 * (rng.rand(n//43) - 0.5)
    all_pts = np.vstack(rows + [rand_pts])
    # 添加一个连续属性 value(可用于 colormap)
    value = np.linalg.norm(all_pts - np.mean(all_pts, axis= 0), axis=1)
    df = pd.DataFrame(all_pts, columns=['x''y''z'])
    df['value'] = value
    # 标注真簇(用于演示)
    labels = np.repeat(np.arange(len(centers)), n//4)
    labels = np.concatenate([labels, np.full(n//4-1)])  # -1 为随机噪声
    df['truth_label'] = labels
    return df

def basic_analysis(df, n_clusters=3):
    """
    一些常用的分析:PCA, KMeans, KDE density (on 3D points).
    返回 dict 包含额外字段
    """

    X = df[['x','y','z']].values
    out = {}
    # PCA 投影,便于制作侧视图或二维投影
    pca = PCA(n_components=3)
    pca_scores = pca.fit_transform(X)
    out['pca'] = pca
    df[['pc1','pc2','pc3']] = pca_scores
    # KMeans 聚类
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(X)
    df['kmeans_label'] = kmeans.labels_
    out['kmeans'] = kmeans
    # KDE 估计
    kde = KernelDensity(bandwidth=0.6).fit(X)
    logdens = kde.score_samples(X)
    df['log_density'] = logdens
    out['kde'] = kde
    return df, out

# 查看
if __name__ == "__main__":
    df = generate_synthetic_data()
    df, out = basic_analysis(df)
    print(df.head())

大家把以上函数可直接保存为 common_data.py 并在后续脚本中 import;

也可把代码块复制到每个示例里。

简单说明:

figsize:论文单栏建议宽度 3.4-3.5 inch (~8.6-9 cm),双栏图建议 6.8-7 inch。Matplotlib 以 inch 为单位:figsize=(6.8,6.0) 常用。

DPI 与导出:栅格图使用 300–1200 dpi(高质量打印建议 600 dpi);若含大量点,导出为 TIFF/PNG;若少量图形元素,SVG/PDF 可保真。

点绘参数:marker size、edgecolor、linewidth、alpha、rasterized 需调试:常见起点 s=20-80 (matplotlib scatter 的 s 表示点面积(pt^2))。大量点时 rasterized=True 可减小 SVG 大小。

# 不同案例

我们会给出 5 种样式示例,这些代码都是完整的,大家可以直接使用~

样式 1:Matplotlib,彩色连续映射 + 拟合平面 + 标注

我们用 matplotlib 3D(mpl_toolkits.mplot3d)绘制彩色散点,颜色根据一个连续属性(value 或 log_density),并在三维中拟合一个回归平面(最小二乘),同时显示 colorbar、坐标刻度优化和保存高分辨率图像。

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from common_data import generate_synthetic_data, basic_analysis
from sklearn.linear_model import LinearRegression

# 生成数据和分析
df = generate_synthetic_data(seed=1, n=1200)
df, out = basic_analysis(df, n_clusters=3)

X = df['x'].values
Y = df['y'].values
Z = df['z'].values
C = df['value'].values 

# 拟合平面 (z = ax + by + c) 作为参考面
A = np.c_[df['x'], df['y'], np.ones(df.shape[0])]
coef, _, _, _ = np.linalg.lstsq(A, df['z'], rcond=None)
a, b, c = coef
# 网格用于绘制平面
xx, yy = np.meshgrid(np.linspace(df['x'].min()-0.5, df['x'].max()+0.520),
                     np.linspace(df['y'].min()-0.5, df['y'].max()+0.520))
zz = a*xx + b*yy + c

# Matplotlib 全局风格优化
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = ['Times New Roman']
mpl.rcParams['axes.linewidth'] = 0.8
mpl.rcParams['xtick.direction'] = 'in'
mpl.rcParams['ytick.direction'] = 'in'

fig = plt.figure(figsize=(7,6), dpi=300)
ax = fig.add_subplot(111, projection='3d', proj_type='persp')

# 自定义 colormap(可换成 'viridis' 等)
cmap = mpl.cm.viridis
norm = mpl.colors.Normalize(vmin=C.min(), vmax=C.max())

# 散点
sc = ax.scatter(X, Y, Z,
                c=C, cmap=cmap, norm=norm,
                s=35,                    # marker size (pt^2)
                alpha=0.85,              # 透明度
                edgecolors='k', lw=0.2,  # 黑色轮廓增强对比
                depthshade=True)         # 深度阴影(增加立体感)

# 平面
ax.plot_surface(xx, yy, zz, alpha=0.25, color='gray', rstride=1, cstride=1, linewidth=0, antialiased=True)

# 轴标签、刻度和视角
ax.set_xlabel('X (a.u.)', fontsize=11, labelpad=6)
ax.set_ylabel('Y (a.u.)', fontsize=11, labelpad=6)
ax.set_zlabel('Z (a.u.)', fontsize=11, labelpad=6)
ax.tick_params(axis='both', which='major', labelsize=9)

ax.view_init(elev=20, azim=40)  # 调整至论文效果的相机角度

# Colorbar
cb = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, shrink=0.6, pad=0.02)
cb.set_label('distance from center', fontsize=10)
cb.ax.tick_params(labelsize=9)

# 保存高分辨率图(栅格格式)
plt.tight_layout()
plt.savefig('fig_style1_matplotlib.png', dpi=600, bbox_inches='tight', pad_inches=0.02)
plt.show()
  • s(marker size):35 常用于论文中间密度;若点多可减小到 8-20。
  • edgecolors='k' + lw=0.2:给点加细黑边可以在彩色填充上提高对比度,便于在灰度打印时分辨。
  • alpha:0.7-0.95,过低会丢失色强对比,过高会使重叠区域不明显;若点密集可降低到 0.4-0.6。
  • depthshade=True:为 3D 增强深度感,但在一些 Matplotlib 版本中阴影算法会影响颜色,需检查并调整。
  • 平面透明度 alpha=0.2-0.4:使点仍然突出且平面作为参考。
  • 导出 DPI:600 推荐用于高质量打印;若需要极高分辨率(杂志要求),使用 1200 dpi。

样式 2:Matplotlib,分类型颜色 + 边界 + 投影视图

当数据包含类别信息时,用离散颜色、明显边界和小面板展示主视角 + 两个轴向投影,类似 3D+投影的学术图,便于量化对比。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl
from common_data import generate_synthetic_data, basic_analysis
from matplotlib.colors import ListedColormap

df = generate_synthetic_data(seed=5, n=1000)
df, out = basic_analysis(df, n_clusters=3)

# 使用聚类结果作为类别
labels = df['kmeans_label'].values
unique_labels = np.unique(labels)
n_labels = unique_labels.size

# 使用色盲友好离散调色板(Tableau)
tableau = mpl.cm.get_cmap('tab10')
colors = [tableau(i) for i in range(n_labels)]
cmap = ListedColormap(colors)

fig = plt.figure(figsize=(9,6), dpi=300)
# 主 3D
ax_main = fig.add_subplot(221, projection='3d')
for lab in unique_labels:
    mask = labels == lab
    ax_main.scatter(df.loc[mask,'x'], df.loc[mask,'y'], df.loc[mask,'z'],
                    label=f'Cluster {lab}', s=30, alpha=0.9,
                    edgecolor='k', linewidth=0.3, color=tableau(lab))
ax_main.set_title('3D scatter colored by cluster', fontsize=11)
ax_main.view_init(18120)

# XY 投影
ax_xy = fig.add_subplot(222)
for lab in unique_labels:
    mask = labels == lab
    ax_xy.scatter(df.loc[mask,'x'], df.loc[mask, 'y'], s=12, alpha=0.9,
                  edgecolor='k', linewidth=0.2, color=tableau(lab))
ax_xy.set_xlabel('X'); ax_xy.set_ylabel('Y')
ax_xy.set_title('XY projection', fontsize=10)

# XZ 投影
ax_xz = fig.add_subplot(223)
for lab in unique_labels:
    mask = labels == lab
    ax_xz.scatter(df.loc[mask,'x'], df.loc[mask,'z'], s=12, alpha=0.9,
                  edgecolor='k', linewidth=0.2, color=tableau(lab))
ax_xz.set_xlabel('X'); ax_xz.set_ylabel('Z')
ax_xz.set_title('XZ projection', fontsize=10)

# YZ 投影
ax_yz = fig.add_subplot(224)
for lab in unique_labels:
    mask = labels == lab
    ax_yz.scatter(df.loc[mask,'y'], df.loc[mask,'z'], s=12, alpha=0.9,
                  edgecolor='k', linewidth=0.2, color=tableau(lab))
ax_yz.set_xlabel('Y'); ax_yz.set_ylabel('Z')
ax_yz.set_title('YZ projection', fontsize=10)

plt.tight_layout()
plt.savefig('fig_style2_panels.png', dpi=600, bbox_inches='tight')
plt.show()

多面板展示能清楚地把 3D 空间信息在 2D 中补充说明,适合论文 figure 的“左图 3D 右图投影”布局。

离散颜色使用 tab10(或 tab20)等,保证类别间对比强烈且在黑白打印时仍可通过点轮廓或不同 marker 区分。

点轮廓在小点上可能引起 Matplotlib 警告,可把 marker='o' 并 edgecolor 设置为较细 linewidth。

样式 3:Plotly,高交互与出版前探索

用 Plotly 绘制交互 3D 散点,添加 hover 信息(如各点的属性),同时展示可自定义 colorscale 与动画,最后把静态高分辨率图导出。

import


    
 plotly.express as px
import plotly.io as pio
from common_data import generate_synthetic_data, basic_analysis

df = generate_synthetic_data(seed=2, n=800)
df, out = basic_analysis(df, n_clusters=3)

fig = px.scatter_3d(df, x='x', y='y', z='z',
                    color='value',                # 连续映射
                    color_continuous_scale='plasma',
                    size_max=6,
                    opacity=0.9,
                    hover_data={'x':True'y':True'z':True'value':True'kmeans_label':True})
# 调整布局
fig.update_layout(scene = dict(
                    xaxis_title='X',
                    yaxis_title='Y',
                    zaxis_title='Z'),
                  font=dict(family='Times New Roman', size=12),
                  coloraxis_colorbar=dict(title='distance'),
                  margin=dict(l=0, r=0, t=30, b=0))

# 可以通过 camera 参数设置初始视角
fig.update_layout(scene_camera=dict(eye=dict(x=1.25, y=1.25, z=0.8)))

# 保存交互 HTML
fig.write_html('fig_style3_plotly_interactive.html')

# 导出高分辨率 PNG
pio.write_image(fig, 'fig_style3_plotly.png', width=1600, height=1200, scale=2)
# 或写 PDF
pio.write_image(fig, 'fig_style3_plotly.pdf', width=1600, height=1200, scale=2)

Plotly 方便在论文投稿前做交互探索(旋转、缩放、筛选)并能导出高分辨率栅格或 PDF。大家在导出时,需要 kaleido 或 Orca。

colorscale 选 plasma/viridis 可保持审美;size_max 控制点在交互中的最大尺寸。

若需动画旋转,可生成若干帧改变 camera.eye 并合成为 gif,但 Plotly 也支持 camera 动画参数。

样式 4:PyVista/Mayavi,推荐用于高质量图与等值面

用 VTK 后端(PyVista 或 Mayavi)做高质量光照、点云渲染和等值面(isosurface)展示,能得到更真实的亮度与阴影,适合展示点云密度或连续场的三维结构。

大家注意,要安装PyVista/Mayavi,用于交互式或保存高质量 PNG。

import numpy as np
import pyvista as pv
from common_data import generate_synthetic_data, basic_analysis
from sklearn.neighbors import KernelDensity

df = generate_synthetic_data(seed=10, n=2000)
df, out = basic_analysis(df, n_clusters=3)

points = df[['x','y','z']].values
logdens = df['log_density'].values

# PyVista 点云
pc = pv.PolyData(points)
pc['density'] = logdens

p = pv.Plotter(off_screen=True)  # off_screen True 可在无 GUI 的服务器上渲染并保存图片
p.add_mesh(pc.glyph(scale=False, geom=pv.Sphere(radius=0.06)), scalars='density', cmap='viridis', show_scalar_bar=True)
p.set_background('white')
p.camera_position = [(8,8,8), (0,0,0), (0,0,1)]
# 添加光照和阴影,使层次更丰富
p.add_light(pv.Light(position=(10,10,10), focal_point=(0,0,0), color='white', intensity=0.8))

# 保存高分辨率图
p.show(screenshot='fig_style4_pyvista.png', window_size=(1600,1200))

PyVista 依赖 VTK,渲染具有真实光照与材质属性,点可以用球体 glyph 渲染。

glyph 的几何体和半径需要根据坐标尺度调整(radius=0.06 只是示例)。大量点使用 glyph 会消耗内存,必要时使用点渲染或降采样。

样式 5:密度/等值面可视化 + 多面比较

除了点云外,我们常希望展示点密度的等值面以揭示结构。

做法:用 KDE 在体素网格上估计密度,再用 marching_cubes提取等值面并渲染,或在 PyVista 直接做 contour。

import numpy as np
import pyvista as pv
from sklearn.neighbors import KernelDensity
from common_data import generate_synthetic_data, basic_analysis

df = generate_synthetic_data(seed=3, n=2000)
df, out = basic_analysis(df, n_clusters=3)

points = df[['x','y','z']].values

# 估计密度到规则网格上
xmin, ymin, zmin = points.min(axis=0) - 0.5
xmax, ymax, zmax = points.max(axis=0) + 0.5

nx = ny = nz = 64

xs = np.linspace(xmin, xmax, nx)
ys = np.linspace(ymin, ymax, ny)
zs = np.linspace(zmin, zmax, nz)
Xg, Yg, Zg = np.meshgrid(xs, ys, zs, indexing='xy')

grid_coords = np.vstack([Xg.ravel(), Yg.ravel(), Zg.ravel()]).T

# KDE
kde = KernelDensity(bandwidth=0.5).fit(points)
logdens_flat = kde.score_samples(grid_coords)

# reshape to (nx, ny, nz)
logdens_grid = logdens_flat.reshape((nx, ny, nz), order='C')

# 创建 pyvista UniformGrid
grid = pv.ImageData()
grid.dimensions = np.array(logdens_grid.shape)
grid.origin = (xmin, ymin, zmin)
grid.spacing = (
    (xmax - xmin) / (nx - 1),
    (ymax - ymin) / (ny - 1),
    (zmax - zmin) / (nz - 1),
)

# 关键:contour 必须使用 point_data
grid.point_data["logdens"] = logdens_grid.ravel(order="F")

# contour levels
iso = [
    logdens_grid.max() - 0.5,
    logdens_grid.max() - 1.0,
    logdens_grid.max() - 1.5,
]

contours = grid.contour(isosurfaces=iso, scalars="logdens")

p = pv.Plotter(off_screen=True)
p.add_mesh(contours, opacity=0.35, cmap="plasma")
p.add_points(points, render_points_as_spheres= True, point_size=5, opacity=0.9)
p.set_background("white")
p.show(screenshot="fig_style5_isosurface.png", window_size=(16001200))

KDE 网格维度 nx,ny,nz 影响等值面精细程度,较细的网格需要更多内存/计算时间。

contour 提取多个 isosurface 值以显示不同密度层次;透明度设置使点云仍能透视,等值面能在视觉上突出高密度结构,非常适合点云聚簇结构说明。

# 总结

上述 5 种样式覆盖了从论文风格静态图、投影对比、多面交互到高级渲染与密度等值面的需求。

大家可以根据稿件格式(单栏/双栏)、期刊对分辨率/文件类型的要求选用导出方式。文章配图应兼顾信息传达准确 和 视觉对比清晰,颜色、边界、透明度和面板布局是实现这一目标的关键参数。


Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/190129