from d2l import jax as d2l
import jax
from jax import numpy as jnp
import numpy as np
# Sample datapoints and create y coordinate.
# Each sample xs[i] is drawn at height ys[i], a Gaussian-kernel density
# estimate built only from the samples that precede it (xs[:i]). Using the
# prefix rather than the whole sample appears intentional (TODO confirm):
# ys[0] == 0, and later points tend to sit higher because their sums contain
# more terms, which spreads the scatter vertically.
epsilon = 0.1  # kernel bandwidth: standard deviation of each Gaussian bump
key = jax.random.PRNGKey(8675309)
xs = jax.random.normal(key, (300,))


def _prefix_kde(samples, bandwidth):
    """Return, for every i, the Gaussian KDE at samples[i] using samples[:i].

    Vectorized equivalent of
    ``sum(exp(-(samples[:i] - samples[i])**2 / (2*bw**2))) / sqrt(2*pi*bw**2) / n``
    evaluated for each i. It builds an (n, n) matrix once, so memory is O(n^2).
    """
    count = len(samples)
    # diffs[i, j] = samples[j] - samples[i]
    diffs = samples[None, :] - samples[:, None]
    # earlier[i, j] is True exactly when j < i, matching the slice samples[:i].
    idx = jnp.arange(count)
    earlier = idx[None, :] < idx[:, None]
    bumps = jnp.exp(-diffs**2 / (2 * bandwidth**2))
    return (jnp.sum(jnp.where(earlier, bumps, 0.0), axis=1)
            / jnp.sqrt(2 * jnp.pi * bandwidth**2) / count)


ys = _prefix_kde(xs, epsilon)
# Compute true density: the standard normal pdf, tabulated on a 0.01-spaced
# grid running from the smallest sample up to (not including) the largest.
xd = jnp.arange(xs.min(), xs.max(), 0.01)
yd = jnp.exp(-0.5 * xd**2) / jnp.sqrt(2 * jnp.pi)
# Plot the results: the density curve, the samples as a scatter, a solid
# vertical line at 0, and a dashed purple line at the sample mean (also shown
# in the title).
sample_mean = float(jnp.mean(xs))
d2l.plot(xd, yd, 'x', 'density')
plt = d2l.plt
plt.scatter(xs, ys)
plt.axvline(x=0)
plt.axvline(x=sample_mean, linestyle='--', color='purple')
plt.title(f'sample mean: {sample_mean:.2f}')
plt.show()