import matplotlib.pyplot as plt
import numpy as np
# Plot the phase velocity omega/k of the three ideal-MHD wave modes (fast
# magnetosonic, slow magnetosonic, shear Alfven) as a function of the angle
# theta between the wave vector k and the background field B_0, for the case
# v_A < v_s, and save the figure to 'mhdvph.svg'.
plt.rc('font', family='serif', size=15)
plt.rc('text', usetex=False)  # render the math labels with matplotlib's mathtext

# Angle grid covering [0, pi/2] *including* pi/2, so the curves reach the
# right-hand edge of the axes (np.arange would stop one step short of it).
dt = 1e-2  # nominal angular step
n_samples = int(round(0.5 * np.pi / dt)) + 1
t = np.linspace(0, 0.5 * np.pi, n_samples)

va = 1     # Alfven speed
vs = 1.25  # sound speed
vavs = (va**2) + (vs**2)

# Magnetosonic dispersion relation:
#   v^2 = (vavs +/- sqrt(vavs^2 - 4 va^2 vs^2 cos^2(theta))) / 2
# The '+' root is the fast wave and the '-' root is the slow wave.
# The discriminant is clamped at 0 so round-off cannot produce NaN when
# va == vs and theta == 0.
discriminant = np.maximum(vavs * vavs - 4 * (va**2) * (vs**2) * (np.cos(t)**2), 0.0)
v_fast = np.sqrt(0.5 * vavs + 0.5 * np.sqrt(discriminant))
v_slow = np.sqrt(0.5 * vavs - 0.5 * np.sqrt(discriminant))
# The shear Alfven wave propagates with omega/k = v_A cos(theta).
v_shear = va * np.cos(t)

fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(111)
ax.plot(t, v_fast, label='$v_{fast}$')
ax.plot(t, v_slow, label='$v_{slow}$')
ax.plot(t, v_shear, label='$v_{shear}$')
ax.legend(loc='right')

# Backslashes are doubled for the TeX commands; '\n' is a real line break that
# puts the k/B_0 orientation hint underneath the tick value.
ax.set_xticks([0, np.pi * .25, np.pi * .5])
ax.set_xticklabels([
    '$0$\n$\\mathbf{k}\\parallel\\mathbf{B}_0$',
    '$\\frac{\\pi}{4}$',
    '$\\frac{\\pi}{2}$\n$\\mathbf{k}\\perp\\mathbf{B}_0$',
])
ax.set_xlim([0, np.pi / 2])

# The fast wave peaks at sqrt(va^2 + vs^2) for theta = pi/2, which sets the
# top of the y-range.
ax.set_ylim([0, np.sqrt(vavs)])
ax.set_yticks([va, vs, np.sqrt(vavs)])
ax.set_yticklabels(['$v_A$', '$v_s$', '$\\sqrt{v_A^2+v_s^2}$'])

ax.set_xlabel('$\\theta$', fontdict={'size': 20})
ax.set_ylabel('$\\frac{\\omega}{k}$', rotation=0, fontdict={'size': 20})
ax.set_title('Phase velocity $v_{ph}=\\frac{\\omega}{k}$ with respect to $\\theta$, $v_A<v_s$')
fig.savefig('mhdvph.svg', bbox_inches='tight')