%matplotlib inline
from d2l import mxnet as d2l
from mpl_toolkits import mplot3d
from mxnet import np, npx
npx.set_np()
Deep learning loss surfaces are not convex. So why a chapter on convexity? Two reasons:
If a method fails even on a convex problem, there is little reason to expect it to work on a nonconvex one.
A set \mathcal{X} is convex if for all x, x' \in \mathcal{X} and \lambda \in [0, 1], \lambda x + (1-\lambda) x' \in \mathcal{X}. That is, it contains the entire line segment between any two of its points.
Nonconvex (left) vs. convex (middle, right).
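The definition can be probed numerically. The sketch below is only an illustration under stated assumptions: the disk, the ring, and the helper names (`in_disk`, `in_annulus`, `segment_stays_inside`) are hypothetical choices, not part of the chapter's code, and a finite grid of \lambda values can expose a counterexample but cannot prove convexity.

```python
import math

def in_disk(p, radius=1.0):
    """Membership test for the closed disk of the given radius (a convex set)."""
    return math.hypot(p[0], p[1]) <= radius

def in_annulus(p, inner=0.5, outer=1.0):
    """Membership test for a ring with a hole in the middle (a nonconvex set)."""
    return inner <= math.hypot(p[0], p[1]) <= outer

def segment_stays_inside(contains, a, b, steps=100):
    """Check lambda * a + (1 - lambda) * b for lambda on a grid over [0, 1]."""
    for i in range(steps + 1):
        lam = i / steps
        point = (lam * a[0] + (1 - lam) * b[0],
                 lam * a[1] + (1 - lam) * b[1])
        if not contains(point):
            return False
    return True

# Both endpoints belong to the disk and to the ring.
a, b = (-0.8, 0.0), (0.8, 0.0)
disk_ok = segment_stays_inside(in_disk, a, b)
annulus_ok = segment_stays_inside(in_annulus, a, b)
print(disk_ok, annulus_ok)
```

The segment passes through the origin, which lies in the disk but falls into the hole of the ring, so the ring fails the test.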
Given a convex set \mathcal{X}, a function f: \mathcal{X} \to \mathbb{R} is convex if for all x, x' \in \mathcal{X} and \lambda \in [0, 1]:
\lambda f(x) + (1-\lambda) f(x') \;\geq\; f(\lambda x + (1-\lambda) x').
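As a rough numerical check of this inequality, one can compare both sides on a grid of \lambda values. The sketch below uses plain-Python stand-ins for a parabola, a cosine, and an exponential; the helper `violates_convexity` and its tolerance are assumptions made for illustration. Finding no violation on one segment does not prove convexity, whereas a single violation does disprove it.

```python
import math

def violates_convexity(func, x1, x2, steps=50, tol=1e-12):
    """Return True if some lambda on the grid breaks
    lambda * f(x1) + (1 - lambda) * f(x2) >= f(lambda * x1 + (1 - lambda) * x2)."""
    for i in range(steps + 1):
        lam = i / steps
        chord = lam * func(x1) + (1 - lam) * func(x2)
        value = func(lam * x1 + (1 - lam) * x2)
        if chord < value - tol:  # small tolerance for floating-point noise
            return True
    return False

f = lambda x: 0.5 * x**2               # convex
g = lambda x: math.cos(math.pi * x)    # nonconvex
h = lambda x: math.exp(0.5 * x)        # convex

results = [violates_convexity(func, -1.5, 1.0) for func in (f, g, h)]
print(results)
```

Only the cosine produces a violation on the segment from -1.5 to 1.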
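Geometrically, the chord between any two points on the graph of a convex function lies on or above the graph. The plots below draw one chord each for a parabola, a cosine, and an exponential; only the cosine is nonconvex: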
f = lambda x: 0.5 * x**2 # Convex
g = lambda x: d2l.cos(np.pi * x) # Nonconvex
h = lambda x: d2l.exp(0.5 * x) # Convex
x, segment = d2l.arange(-2, 2, 0.01), d2l.tensor([-1.5, 1])
d2l.use_svg_display()
_, axes = d2l.plt.subplots(1, 3, figsize=(9, 3))
for ax, func in zip(axes, [f, g, h]):
    d2l.plot([x, segment], [func(x), func(segment)], axes=ax)
Intersections of convex sets are convex, which is useful because a feasible set defined by several convex constraints stays convex. Unions of convex sets, in general, are not.
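A one-dimensional example makes the contrast between intersections and unions concrete. The intervals and names below (`in_interval`, `A`, `B`, `C`) are hypothetical and chosen only for illustration; the code checks a single midpoint in each case rather than proving anything in general.

```python
def in_interval(lo, hi):
    """Return a membership test for the closed interval [lo, hi], a convex set."""
    return lambda x: lo <= x <= hi

A = in_interval(0.0, 2.0)
B = in_interval(1.0, 3.0)
C = in_interval(4.0, 5.0)

in_A_and_B = lambda x: A(x) and B(x)  # intersection, equal to [1, 2]
in_A_or_C = lambda x: A(x) or C(x)    # union of two disjoint intervals

def midpoint(x, y):
    """Convex combination with lambda = 0.5."""
    return 0.5 * x + 0.5 * y

mid_intersection = midpoint(1.0, 2.0)  # both endpoints lie in the intersection
mid_union = midpoint(2.0, 4.0)         # both endpoints lie in the union
print(in_A_and_B(mid_intersection), in_A_or_C(mid_union))
```

The midpoint of two points in the intersection stays inside it, while the midpoint of 2 and 4 falls in the gap between the two intervals of the union.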
Key property: every local minimum of a convex function is also a global minimum. So convex optimization can’t get stuck in a “wrong” minimum.
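Proof sketch: suppose x^* is a local minimum but some x' has f(x') < f(x^*). For \lambda \in [0, 1) close to 1, the point \lambda x^* + (1-\lambda) x' is arbitrarily close to x^*, yet convexity gives f(\lambda x^* + (1-\lambda) x') \le \lambda f(x^*) + (1-\lambda) f(x') < f(x^*), contradicting the local minimality of x^*.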
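The practical consequence can be seen with a simple local method. The sketch below assumes plain gradient descent with a fixed step size and compares the convex parabola 0.5 x^2 with a hypothetical nonconvex quartic `q` picked for illustration; the step size, step count, and starting points are likewise arbitrary choices.

```python
def gradient_descent(grad, x0, lr=0.01, steps=2000):
    """Plain gradient descent on a scalar variable with a fixed step size."""
    x = x0
    for _ in range(steps):
        x -= lr * grad(x)
    return x

# Convex example and its derivative.
f = lambda x: 0.5 * x**2
f_grad = lambda x: x

# Hypothetical nonconvex example with two local minima of different depth.
q = lambda x: x**4 - 2 * x**2 + 0.5 * x
q_grad = lambda x: 4 * x**3 - 4 * x + 0.5

starts = [-2.0, 2.0]
convex_ends = [gradient_descent(f_grad, x0) for x0 in starts]
nonconvex_ends = [gradient_descent(q_grad, x0) for x0 in starts]

print("convex:   ", convex_ends)
print("nonconvex:", nonconvex_ends, [q(x) for x in nonconvex_ends])
```

On the parabola both runs end at essentially the same point. On the quartic they settle in different basins with different objective values, so the run started from the right ends in a local minimum that is not global.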
For constrained problems \min f(x) s.t. c_i(x) \le 0, three workhorse techniques:
Projecting an external point onto a convex set.
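For simple convex sets the projection has a closed form. The sketch below covers two such cases, a Euclidean ball centered at the origin and an axis-aligned box; the function names and the sample points are assumptions made for illustration, and general convex sets usually need an optimization routine instead.

```python
import math

def project_onto_ball(x, radius=1.0):
    """Euclidean projection of x onto the ball {z : ||z||_2 <= radius}."""
    norm = math.sqrt(sum(v * v for v in x))
    if norm <= radius:
        return list(x)       # points already inside are left unchanged
    scale = radius / norm    # otherwise rescale onto the boundary
    return [v * scale for v in x]

def project_onto_box(x, lo, hi):
    """Euclidean projection onto the box [lo, hi]^n: clip each coordinate."""
    return [min(max(v, lo), hi) for v in x]

outside = [3.0, 4.0]
on_ball = project_onto_ball(outside)       # moved onto the unit sphere
inside = project_onto_ball([0.3, 0.4])     # already feasible, unchanged
boxed = project_onto_box([-2.0, 0.5, 7.0], 0.0, 1.0)
print(on_ball, inside, boxed)
```

The projected point keeps the direction of the original and has norm equal to the radius. Applying either projection a second time changes nothing, as expected of a projection.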