# Overlay two SHAP views on the same feature rows:
#   * ax1 (bottom x-axis): bee swarm of per-sample Shapley contributions;
#   * ax2 (top x-axis, twin sharing ax1's y-axis): faint bar chart of the
#     mean Shapley value per feature.

# Shared axes rectangle [left, bottom, width, height] in figure fractions.
# The wide left margin leaves room for feature names; the top margin leaves
# room for ax2's tick labels and axis label.
axes_rect = [0.2, 0.2, 0.65, 0.65]
# Both plots must show the same number of rows so the bars line up with the
# swarm rows they annotate.
max_display = 20

fig, ax1 = plt.subplots(figsize=(10, 8))

# Bee swarm on the primary axes. plot_size=None leaves the figure size chosen
# above untouched (the default "auto" resizes the current figure).
plt.sca(ax1)
shap.summary_plot(
    shap_values_numpy,
    X_sample,
    feature_names=X_sample.columns,
    max_display=max_display,
    plot_type="dot",
    plot_size=None,
    show=False,
    color_bar=True,
)
ax1.set_position(axes_rect)

# Twin axes: shares ax1's y-axis (feature rows), independent top x-axis.
ax2 = ax1.twiny()
# summary_plot draws into the current axes; make ax2 current explicitly so the
# bars land on the twin (they are read back from ax2.patches below).
plt.sca(ax2)
shap.summary_plot(
    shap_values_numpy,
    X_sample,
    feature_names=X_sample.columns,
    max_display=max_display,
    plot_type="bar",
    plot_size=None,
    show=False,
)
ax2.set_position(axes_rect)

# Fade the bars so the bee swarm drawn underneath stays visible.
bars = ax2.patches
for bar in bars:
    bar.set_alpha(0.2)

ax1.set_xlabel("Shapley Value Contribution (Bee Swarm)", fontsize=12)
ax1.set_ylabel("Features", fontsize=12)

# Keep ax2's label and ticks on top so the two x-axes do not collide
# (summary_plot may have moved them), and re-show its frame edges.
ax2.set_xlabel("Mean Shapley Value (Feature Importance)", fontsize=12)
ax2.xaxis.set_label_position("top")
ax2.xaxis.tick_top()
ax2.spines["top"].set_visible(True)
ax2.spines["right"].set_visible(True)

# Deliberately no plt.tight_layout(): it recomputes subplot positions from the
# gridspec and would discard the explicit set_position() calls above.
plt.show()