Discretization Drift in Two-Player Games
In this work, we quantify the discretization error induced by gradient descent in two-player games, and use it to understand and improve training in such games, including Generative Adversarial Networks.
Two-player games
Many machine learning applications involve not a single model, but two models trained jointly. The two models can minimise the same loss (common-payoff games), completely adversarial losses (zero-sum games), or anything in between. Well-known examples of two-player games include Generative Adversarial Networks [1] and model-based reinforcement learning [2].
Discretization drift
Training a two-player game with player parameters $\phi$ and $\theta$ to minimise their respective losses $L_1(\phi, \theta)$ and $L_2(\phi, \theta)$ requires discretizing the underlying ODEs which define the game:
$$\dot{\phi} = -\nabla_{\phi} L_1( \phi, \theta) $$ $$ \dot{\theta} = - \nabla_{\theta} L_2( \phi, \theta)$$
Applying Euler discretization to the above ODEs with learning rates $\alpha h$ and $\lambda h$ leads to the familiar simultaneous gradient descent: $$\phi_t = \phi_{t-1} - \alpha h \nabla_{\phi} L_1( \phi_{t-1}, \theta_{t-1}) $$ $$ \theta_t = \theta_{t-1} - \lambda h \nabla_{\theta} L_2( \phi_{t-1}, \theta_{t-1}) $$
Alternatively, we can update the first player and use its updated parameters when updating the second player, leading to alternating gradient descent:
$$\phi_t = \phi_{t-1} - \alpha h \nabla_{\phi} L_1( \phi_{t-1}, \theta_{t-1}) $$ $$ \theta_t = \theta_{t-1} - \lambda h \nabla_{\theta} L_2( \phi_t, \theta_{t-1}) $$
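As a concrete reference, here is a minimal sketch of both update rules in JAX, using a hypothetical bilinear zero-sum objective $E(\phi, \theta) = \phi^\top \theta$ as a stand-in for the players' losses; the toy game and function names are illustrative, not taken from the paper's experiments.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy game: L1 = E and L2 = -E, with E(phi, theta) = phi^T theta.
E = lambda phi, theta: jnp.dot(phi, theta)
L1 = E
L2 = lambda phi, theta: -E(phi, theta)

grad_L1 = jax.grad(L1, argnums=0)  # gradient with respect to phi
grad_L2 = jax.grad(L2, argnums=1)  # gradient with respect to theta

def simultaneous_step(phi, theta, alpha_h, lambda_h):
    # Both players are updated from the same iterate (phi_{t-1}, theta_{t-1}).
    return phi - alpha_h * grad_L1(phi, theta), theta - lambda_h * grad_L2(phi, theta)

def alternating_step(phi, theta, alpha_h, lambda_h):
    # The second player is updated using the already-updated first player phi_t.
    new_phi = phi - alpha_h * grad_L1(phi, theta)
    return new_phi, theta - lambda_h * grad_L2(new_phi, theta)

phi, theta = jnp.array([1.0, -0.5]), jnp.array([0.3, 0.8])
print(simultaneous_step(phi, theta, 0.01, 0.01))
print(alternating_step(phi, theta, 0.01, 0.01))
```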
We can generalise the above ODEs beyond differentiable two-player games, and write:
$$\dot{\phi} = f( \phi, \theta) $$ $$\dot{\theta} = g( \phi, \theta)$$
Using Euler discretization on the above ODEs leads to simultaneous Euler updates:
$$\phi_t = \phi_{t-1} + \alpha h f( \phi_{t-1}, \theta_{t-1})$$ $$\theta_t = \theta_{t-1} + \lambda h g( \phi_{t-1}, \theta_{t-1}) $$
and the alternating Euler updates:
$$\phi_t = \phi_{t-1} + \alpha h f( \phi_{t-1}, \theta_{t-1}) $$ $$\theta_t = \theta_{t-1} + \lambda h g( \phi_t, \theta_{t-1}) $$
For $f = -\nabla_{\phi} L_1$ and $g = - \nabla_{\theta} L_2$, we recover the differentiable two-player setting above.
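The general updates translate directly into code as well; below is a short sketch (function names are ours) of simultaneous and alternating Euler steps for arbitrary update functions $f$ and $g$, shown here with the vector field of the same hypothetical bilinear game used above.

```python
import jax.numpy as jnp

def simultaneous_euler_step(f, g, phi, theta, alpha_h, lambda_h):
    # Both players are updated from the same previous iterate (phi, theta).
    return phi + alpha_h * f(phi, theta), theta + lambda_h * g(phi, theta)

def alternating_euler_step(f, g, phi, theta, alpha_h, lambda_h):
    # The second player's update uses the already-updated first player phi_t.
    new_phi = phi + alpha_h * f(phi, theta)
    return new_phi, theta + lambda_h * g(new_phi, theta)

# Example vector field: the bilinear zero-sum game E(phi, theta) = phi^T theta,
# for which f = -grad_phi E = -theta and g = grad_theta E = phi.
f = lambda phi, theta: -theta
g = lambda phi, theta: phi
phi, theta = jnp.array([1.0, -0.5]), jnp.array([0.3, 0.8])
print(simultaneous_euler_step(f, g, phi, theta, 0.01, 0.01))
print(alternating_euler_step(f, g, phi, theta, 0.01, 0.01))
```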
While Euler discretization is a simple numerical integration method which allows us to minimise functions efficiently, it introduces an error which makes the discrete updates deviate from the original ODE:
Discretization Drift: The discrete updates follow different trajectories compared to the continuous flow given by the original ODE.
Backward error analysis
To better understand gradient descent, and Euler integration more broadly, we use backward error analysis to construct a set of modified ODEs which better describe the trajectory taken by the discrete updates:
Illustration: For each player we find the modified ODE which captures the change in parameters introduced by the discrete updates with an error of $\mathcal{O}(h^3)$, where $h$ is the learning rate; the original ODE has an error of $\mathcal{O}(h^2)$.
This approach was introduced in supervised learning by Barrett and Dherin [3]. Using it, we find the following modified ODEs, which better describe the discrete Euler dynamics by reducing the discretization error from $\mathcal{O}(h^2)$ to $\mathcal{O}(h^3)$.
For simultaneous updates:
$$\dot{\tilde{\phi}} = f - \frac{\alpha h}{2} \left(f \nabla_{\phi} f + g \nabla_{\theta}f\right)$$ $$\dot{\tilde{\theta}} = g - \frac{\lambda h}{2} \left(g \nabla_{\theta} g + f\nabla_{\phi} g \right)$$
For alternating updates:
$$\dot{\tilde{\phi}} = f - \frac{\alpha h}{2} \left(f \nabla_{\phi} f + g \nabla_{\theta}f\right)$$ $$\dot{\tilde{\theta}} = g - \frac{\lambda h}{2} \left(g \nabla_{\theta} g + (1 - \frac{2 \alpha}{\lambda}) f\nabla_{\phi} g \right)$$
Definition: The discretization drift of each player has two terms: one containing only that player's own update function, which we call the self term, and one that also contains the other player's update function, which we call the interaction term.
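In code, these drift terms are Jacobian-vector products and can be evaluated with `jax.jvp` without forming full Jacobians. The sketch below is ours (parameters assumed to be flat arrays; `alpha_h` and `lambda_h` denote the products $\alpha h$ and $\lambda h$, whose ratio equals $\alpha/\lambda$), not an excerpt from any released implementation.

```python
import jax
import jax.numpy as jnp

def modified_vector_fields(f, g, phi, theta, alpha_h, lambda_h, alternating=False):
    # Right-hand sides of the modified ODEs; drift terms such as (nabla_phi f) f
    # are Jacobian-vector products computed with jax.jvp.
    f_val, g_val = f(phi, theta), g(phi, theta)

    # First player: self term (nabla_phi f) f and interaction term (nabla_theta f) g.
    f_self = jax.jvp(lambda p: f(p, theta), (phi,), (f_val,))[1]
    f_interaction = jax.jvp(lambda t: f(phi, t), (theta,), (g_val,))[1]

    # Second player: self term (nabla_theta g) g and interaction term (nabla_phi g) f.
    g_self = jax.jvp(lambda t: g(phi, t), (theta,), (g_val,))[1]
    g_interaction = jax.jvp(lambda p: g(p, theta), (phi,), (f_val,))[1]

    # Alternating updates rescale the second player's interaction term by (1 - 2*alpha/lambda);
    # alpha_h / lambda_h = alpha / lambda since h cancels.
    interaction_coeff = 1.0 - 2.0 * alpha_h / lambda_h if alternating else 1.0

    phi_dot = f_val - 0.5 * alpha_h * (f_self + f_interaction)
    theta_dot = g_val - 0.5 * lambda_h * (g_self + interaction_coeff * g_interaction)
    return phi_dot, theta_dot

# Example with the zero-sum game E(phi, theta) = phi^T theta: f = -grad_phi E, g = grad_theta E.
E = lambda phi, theta: jnp.dot(phi, theta)
f = lambda phi, theta: -jax.grad(E, argnums=0)(phi, theta)
g = lambda phi, theta: jax.grad(E, argnums=1)(phi, theta)
print(modified_vector_fields(f, g, jnp.array([1.0, -0.5]), jnp.array([0.3, 0.8]), 0.01, 0.02, alternating=True))
```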
These modified ODEs allow us to better understand the behaviour of simultaneous and alternating Euler updates and the difference between them, and to use stability analysis to characterise their behaviour around equilibria.
Illustration: The modified ODEs follow the discrete updates much more closely than the original ODEs.
Gradient descent in zero-sum games and GANs
We can specialise the general results above to zero-sum games, where $L_1 = E$ and $L_2 = -E$, and thus $f = -\nabla_{\phi} E$ and $g = \nabla_{\theta} E$.
For simultaneous updates we obtain the following modified ODEs:
$$\dot{\tilde{\phi}} = - \nabla_{\phi} \left(E + \frac{\alpha h}{4} (|| \nabla_{\phi} E||^2 - || \nabla_{\theta} E||^2) \right) $$ $$\dot{\tilde{\theta}} = - \nabla_{\theta} \left(-E + \frac{\lambda h}{4} (|| \nabla_{\theta} E||^2 - || \nabla_{\phi} E||^2) \right)$$
Since the right-hand side of each of these ODEs is a negative gradient, we can read off the modified loss functions induced by the modified ODEs above, which better describe what gradient descent is actually minimising in zero-sum games:
$$\tilde L_{1} = E + \frac{\alpha h}{4} ||{\nabla_{\phi} E}||^2 - \frac{\alpha h}{4} ||{\nabla_{\theta} E}||^2 $$ $$\tilde L_{2} = - E + \frac{\lambda h}{4} ||{\nabla_{\theta} E}||^2 - \frac{\lambda h}{4} ||{\nabla_{\phi} E}||^2 $$
The modified losses reveal that the interaction terms ($- \frac{\alpha h}{4} ||{\nabla_{\theta} E}||^2$ for the first player and $- \frac{\lambda h}{4} ||{\nabla_{\phi} E}||^2$ for the second) maximise the gradient norm of the other player, which can cause training instability and decrease performance. To test this empirically, we compare zero-sum GANs trained with gradient descent on the original losses of the game ($L_1 = E$ and $L_2 = -E$) against GANs trained with explicit regularisation that cancels the interaction terms ($L_1 = E + \frac{\alpha h}{4} ||{\nabla_{\theta} E}||^2$ and $L_2 = -E + \frac{\lambda h}{4} ||{\nabla_{\phi} E}||^2$). Our intuition holds in practice: explicit regularisation performs substantially better and reaches the same peak performance as Adam [4]:
Illustration: Simultaneous gradient descent: explicit regularisation cancelling the terms which maximise the gradient norm of the other player (the interaction terms) improves performance. Higher is better.
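To make the explicit regularisation above concrete, here is a minimal sketch of the regularised losses for a generic differentiable $E$; the helper names are ours, and in a real GAN $E$ would be the discriminator/generator objective with the gradient norms estimated on mini-batches, so treat this as a schematic rather than the paper's implementation.

```python
import jax
import jax.numpy as jnp

def sq_grad_norm(E, argnums, phi, theta):
    # Squared gradient norm of E with respect to the chosen player's parameters.
    grads = jax.grad(E, argnums=argnums)(phi, theta)
    return jnp.sum(grads ** 2)

def regularised_losses(E, phi, theta, alpha_h, lambda_h):
    # L1 = E + (alpha h / 4) ||grad_theta E||^2 cancels the first player's interaction term;
    # L2 = -E + (lambda h / 4) ||grad_phi E||^2 cancels the second player's interaction term.
    L1 = E(phi, theta) + (alpha_h / 4.0) * sq_grad_norm(E, 1, phi, theta)
    L2 = -E(phi, theta) + (lambda_h / 4.0) * sq_grad_norm(E, 0, phi, theta)
    return L1, L2

# One simultaneous gradient descent step on the regularised losses,
# with a hypothetical toy E standing in for the GAN objective.
E = lambda phi, theta: jnp.dot(phi, theta)
alpha_h = lambda_h = 0.01
phi, theta = jnp.array([1.0, -0.5]), jnp.array([0.3, 0.8])

d_phi = jax.grad(lambda p: regularised_losses(E, p, theta, alpha_h, lambda_h)[0])(phi)
d_theta = jax.grad(lambda t: regularised_losses(E, phi, t, alpha_h, lambda_h)[1])(theta)
phi, theta = phi - alpha_h * d_phi, theta - lambda_h * d_theta
```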
We can perform the same exercise for alternating updates, leading to the modified losses: $$\tilde L_{1} = E + \frac{\alpha h}{4} ||{\nabla_{\phi} E}||^2 - \frac{\alpha h}{4} ||{\nabla_{\theta} E}||^2 $$ $$\tilde L_{2} = - E + \frac{\lambda h}{4} ||{\nabla_{\theta} E}||^2 - \frac{\lambda h}{4} (1 - \frac{2 \alpha}{\lambda}) ||{\nabla_{\phi} E}||^2 $$
Based on the modified losses above and the insights from simultaneous updates, we can predict that learning rate ratios for which $1 - \frac{2 \alpha}{\lambda} < 0$ (that is, $\alpha / \lambda > \frac{1}{2}$) will perform best: the interaction term of the second player then has a positive coefficient, so the second player minimises rather than maximises the gradient norm of the first player, unlike in the simultaneous case. We can again test this empirically in the GAN setting, and confirm this intuition:
Illustration: Alternating gradient descent performs better for learning rate ratios which reduce the adversarial nature of discretization drift. The same learning rate ratios show no advantage in the simultaneous case. Higher is better.
Summary
We have used backward error analysis to derive modified ODEs which better describe gradient descent in two-player games. These modified ODEs provide insight into the difference between simultaneous and alternating updates, as well as into the instabilities that arise when specific types of games, such as zero-sum games, are trained with gradient descent; we then used that insight to improve the stability and performance of GAN training.
References
[1] Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A. and Bengio, Y. Generative Adversarial Nets. Advances in Neural Information Processing Systems, 2014.
[2] Sutton, R. S. and Barto, A. G. Reinforcement Learning: An Introduction. MIT Press, 2018.
[3] Barrett, D. and Dherin, B. Implicit Gradient Regularization. International Conference on Learning Representations, 2021.
[4] Kingma, D. P. and Ba, J. Adam: A Method for Stochastic Optimization. International Conference on Learning Representations, 2015.