How can I create subplots from a set of go.Figure objects that each contain multiple lines and data points? Details below:
# Data visualization
from plotly.subplots import make_subplots
import plotly.graph_objects as go

epoch_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

loss_plots = [
    go.Scatter(x=epoch_list, y=loss_list, mode='lines', name='Loss', line=dict(width=4)),
    go.Scatter(x=epoch_list, y=val_loss_list, mode='lines', name='Validation Loss', line=dict(width=4)),
]
loss_figure = go.Figure(data=loss_plots)

# Error-rate traces plot the error lists (not the loss lists)
error_plots = [
    go.Scatter(x=epoch_list, y=error_rate, mode='lines', name='Error Rate', line=dict(width=4)),
    go.Scatter(x=epoch_list, y=val_error_rate, mode='lines', name='Validation Error Rate', line=dict(width=4)),
]
error_figure = go.Figure(data=error_plots)

metric_figure = make_subplots(
    rows=3, cols=2,
    specs=[[{}, {}],
           [{}, {}],
           [{'colspan': 2}, None]],  # the cell covered by colspan must be None
)

# This raises the error below: append_trace expects a single trace, not a whole go.Figure
metric_figure.append_trace(loss_figure, row=1, col=1)
metric_figure.append_trace(error_figure, row=1, col=2)
metric_figure.show()
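When I try to create the subplots, I get the error "invalid element(s) received for the ‘data’ property of Invalid elements include: [Figure". I think I know why the error occurs, but is there a way around it? I still want to be able to change the layout of each plot and show multiple lines on a single plot.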
Answer:
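append_trace (like add_trace) accepts a single trace object such as go.Scatter, not a whole go.Figure, which is why passing the figure fails. A simple solution is to iterate over the traces stored in each figure's data attribute and add them one at a time to the desired subplot: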
# Data visualization
from plotly.subplots import make_subplots
import plotly.graph_objects as go

epoch_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_loss_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
val_error_rate = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

loss_plots = [
    go.Scatter(x=epoch_list, y=loss_list, mode='lines', name='Loss', line=dict(width=4)),
    go.Scatter(x=epoch_list, y=val_loss_list, mode='lines', name='Validation Loss', line=dict(width=4)),
]
loss_figure = go.Figure(data=loss_plots)

error_plots = [
    go.Scatter(x=epoch_list, y=error_rate, mode='lines', name='Error Rate', line=dict(width=4)),
    go.Scatter(x=epoch_list, y=val_error_rate, mode='lines', name='Validation Error Rate', line=dict(width=4)),
]
error_figure = go.Figure(data=error_plots)

metric_figure = make_subplots(
    rows=3, cols=2,
    specs=[[{}, {}],
           [{}, {}],
           [{'colspan': 2}, None]],  # the cell covered by colspan must be None
)

# Add each trace individually; add_trace(trace, row, col) is the
# non-deprecated replacement for append_trace
for t in loss_figure.data:
    metric_figure.add_trace(t, row=1, col=1)
for t in error_figure.data:
    metric_figure.add_trace(t, row=1, col=2)

metric_figure.show()