%matplotlib inline
from d2l import jax as d2l
from IPython import display
import jax
from jax import numpy as jnp
import numpy as np
# Sample f(x) = sin(x^x) on an evenly spaced grid that starts at 0.01 and
# advances in steps of 0.01 (the stop value 3.01 is exclusive), then plot it.
x_big = jnp.arange(0.01, 3.01, 0.01)
# Element-wise x^x, written as an explicit power call rather than `**`.
ys = jnp.sin(jnp.power(x_big, x_big))
d2l.plot(
    x_big,
    ys,
    'x',
    'f(x)',
)