<?xml version="1.0" encoding="utf-8" standalone="yes"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom">
  <channel>
    <title>Categories on Carl Fredriksson</title>
    <link>https://cfml.se/categories/</link>
    <description>Recent content in Categories on Carl Fredriksson</description>
    <generator>Hugo -- gohugo.io</generator>
    <language>en-US</language>
    
	<atom:link href="https://cfml.se/categories/index.xml" rel="self" type="application/rss+xml" />
    
    
    
    <item>
      <title>Causal Inference Basics</title>
      <link>https://cfml.se/causal-inference-basics/</link>
      <pubDate>Tue, 30 Jan 2024 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/causal-inference-basics/</guid>
      <description>&lt;iframe
    src=&#34;https://cfml.se/notebooks/causal_inference_basics.html&#34;
    width=&#34;100%&#34;
    scrolling=&#34;no&#34;
    onload=&#34;this.style.height=(this.contentDocument.body.scrollHeight + 150) + &#39;px&#39;;&#34;
    style=&#34;border:none;&#34;&gt;
&lt;/iframe&gt;
</description>
    </item>
    
    
    
    <item>
      <title>RL Gymnasium Training</title>
      <link>https://cfml.se/rl-gymnasium-training/</link>
      <pubDate>Thu, 09 Nov 2023 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/rl-gymnasium-training/</guid>
      <description>&lt;p&gt;After my second, significantly more thorough reading of &lt;a href=&#34;http://incompleteideas.net/book/the-book.html&#34;&gt;Reinforcement Learning: An Introduction by Sutton and Barto (2nd edition)&lt;/a&gt;, I wanted to get my hands dirty and apply some of what I&amp;rsquo;d learned. I created the repository &lt;a href=&#34;https://github.com/CarlFredriksson/rl-gymnasium-training&#34;&gt;rl-gymnasium-training&lt;/a&gt; to work on my Reinforcement Learning (RL) skills using Gymnasium environments (&lt;a href=&#34;https://gymnasium.farama.org/&#34;&gt;https://gymnasium.farama.org/&lt;/a&gt;). There is no substitute for trying to implement and apply algorithms yourself - In my experience it not only improves your practical skills in a field, but also improves your understanding of the theory. At the time of writing this I had solved three simple gym environments, below are some gif examples:&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th style=&#34;text-align:center&#34;&gt;Random Agent&lt;/th&gt;
&lt;th style=&#34;text-align:center&#34;&gt;Trained Agent&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td style=&#34;text-align:center&#34;&gt;&lt;img src=&#34;https://cfml.se/images/rl_gymnasium_training/cart_pole_random_agent.gif&#34; alt=&#34;CartPole Random Agent&#34;&gt;&lt;/td&gt;
&lt;td style=&#34;text-align:center&#34;&gt;&lt;img src=&#34;https://cfml.se/images/rl_gymnasium_training/cart_pole_trained_agent.gif&#34; alt=&#34;CartPole Trained Agent&#34;&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style=&#34;text-align:center&#34;&gt;&lt;img src=&#34;https://cfml.se/images/rl_gymnasium_training/lunar_lander_random_agent.gif&#34; alt=&#34;Lunar Lander Random Agent&#34;&gt;&lt;/td&gt;
&lt;td style=&#34;text-align:center&#34;&gt;&lt;img src=&#34;https://cfml.se/images/rl_gymnasium_training/lunar_lander_trained_agent.gif&#34; alt=&#34;Lunar Lander Trained Agent&#34;&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style=&#34;text-align:center&#34;&gt;&lt;img src=&#34;https://cfml.se/images/rl_gymnasium_training/frozen_lake_random_agent.gif&#34; alt=&#34;FrozenLake Random Agent&#34;&gt;&lt;/td&gt;
&lt;td style=&#34;text-align:center&#34;&gt;&lt;img src=&#34;https://cfml.se/images/rl_gymnasium_training/frozen_lake_trained_agent.gif&#34; alt=&#34;CartPole Trained Agent&#34;&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;Check out the repository if you&amp;rsquo;re interested, and feel free to reach out with any questions!&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Bellman Operators are Contractions</title>
      <link>https://cfml.se/bellman-operators-are-contractions/</link>
      <pubDate>Mon, 07 Nov 2022 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/bellman-operators-are-contractions/</guid>
      <description>&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;I&amp;rsquo;m studying Reinforcement Learning (RL) again, primarily using the books:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;a href=&#34;http://incompleteideas.net/book/the-book.html&#34;&gt;Reinforcement Learning: An Introduction by Sutton and Barto (2nd edition)&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href=&#34;https://github.com/MathFoundationRL/Book-Mathmatical-Foundation-of-Reinforcement-Learning&#34;&gt;Mathematical Foundations of Reinforcement Learning by Zhao&lt;/a&gt;.&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;I&amp;rsquo;ve read book 1 once before, often referred to as the RL textbook or even the RL bible, but I had forgotten most of it and wanted to study the book in a slower and more careful manner, with the aim of finishing it with a deeper understanding than I did the first time. At the time of writing this I&amp;rsquo;m still in the early parts of the book. What prompted this post was to document what I learned while trying to answer the following questions that I struggled with for a while:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;How do we know that the Bellman equations and Bellman optimality equations have unique solutions?&lt;/li&gt;
&lt;li&gt;Why can&amp;rsquo;t iterative policy evaluation or policy/value iteration get stuck in local optima?&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;To answer these questions I had to improve my mathematical foundations, and I&amp;rsquo;m grateful to have found book 2 which has been a big help along with various other online resources. With this post I hope to preserve my knowledge and to help anyone with similar questions that happen to find it. If you have any questions or feedback, please feel free to reach out to me at &lt;a href=&#34;mailto:c@cfml.se&#34;&gt;c@cfml.se&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;contraction-mappings-and-the-banach-fixed-point-theorem&#34;&gt;Contraction mappings and the Banach fixed-point theorem&lt;/h2&gt;
&lt;p&gt;From the Wikipedia article &lt;a href=&#34;https://en.wikipedia.org/wiki/Banach_fixed-point_theorem&#34;&gt;Banach fixed-point theorem&lt;/a&gt;:&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;Definition. Let $(X,d)$ be a &lt;a href=&#34;https://en.wikipedia.org/wiki/Complete_metric_space&#34;&gt;complete metric space&lt;/a&gt;. Then a map $T:X \to X$ is called a &lt;a href=&#34;https://en.wikipedia.org/wiki/Contraction_mapping&#34;&gt;contraction mapping&lt;/a&gt; on $X$ if there exists $q \in [0,1)$ such that&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;$$
d\big(T(x),T(y)\big) \leq qd(x,y)
$$&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;for all $x,y \in X$.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;Banach Fixed Point Theorem&lt;/strong&gt;. Let $(X,d)$ be a &lt;a href=&#34;https://en.wikipedia.org/wiki/Empty_set&#34;&gt;non-empty&lt;/a&gt; complete metric space with a contraction mapping $T:X\to X$. Then $T$ admits a unique &lt;a href=&#34;https://en.wikipedia.org/wiki/Fixed_point_(mathematics)&#34;&gt;fixed-point&lt;/a&gt; $x^*$ in $X$ (i.e. $T(x^*) = x^*$). Furthermore, $x^*$ can be found as follows: start with an arbitrary element $x_0 \in X$ and define a sequence $(x_n)_{n \in \mathbb{N}}$ by $x_n = T(x_{n-1})$ for $n \geq 1$. Then $\lim_{n \to \infty} x_n = x^*$.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Thus, we need to prove that the right hand side of the Bellman equations and Bellman optimality equations are contractions to answer both questions stated in the introduction.&lt;/p&gt;
&lt;p&gt;You can see a proof for the theorem in the Wikipedia article &lt;a href=&#34;https://en.wikipedia.org/wiki/Banach_fixed-point_theorem&#34;&gt;Banach fixed-point theorem&lt;/a&gt;. The rough idea is that a sequence $(x_n)_{n \in \mathbb{N}}$ by $x_n = T(x_{n-1})$ for $n \geq 1$ is shown to be a &lt;a href=&#34;https://en.wikipedia.org/wiki/Cauchy_sequence&#34;&gt;Cauchy sequence&lt;/a&gt; (a sequence whose elements become arbitrarily close to each other as the sequence progresses). Thus, the sequence converges to $x^* \in X$ which has to be a fixed point of $T$. It has to be the only fixed point of $T$ since any pair of distinct fixed points $p_1$ and $p_2$ would contradict $T$ being a contraction:&lt;/p&gt;
&lt;p&gt;$$
d\big(T(p_1), T(p_2)\big) = d(p_1, p_2) &amp;gt; q d(p_1, p_2)
$$&lt;/p&gt;
&lt;h2 id=&#34;bellman-equations&#34;&gt;Bellman equations&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s consider any finite &lt;a href=&#34;https://en.wikipedia.org/wiki/Markov_decision_process&#34;&gt;Markov decision process&lt;/a&gt; (MDP) where $\mathcal{S}$ denote the finite set of all states, $\mathcal{R} \subset \mathbb{R}$ the finite set of all rewards, and $\mathcal{A}(s)$ the finite set of all available actions $a$ in state $s$. Let $r_\pi(s)$ denote the expected immediate reward and $p_\pi(s^\prime | s)$ the probability of transitioning to state $s^\prime$ when following policy $\pi$ from state $s$. Let $\gamma \in (0, 1)$ denote the discount rate for future rewards. I&amp;rsquo;m assuming some knowledge about MDPs and won&amp;rsquo;t explain all of the notation here, feel free to check my older post &lt;a href=&#34;https://cfml.se/markov-decision-processes/&#34;&gt;Markov Decision Processes&lt;/a&gt; for an introduction.&lt;/p&gt;
&lt;p&gt;The Bellman equation for the state-value function for policy $\pi$ can be defined as follows:&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
v_{\pi}(s) &amp;= \mathbb{E}_\pi \big[R_{t+1} + \gamma v_{\pi}(S_{t+1}) | S_t = s \big] \\
&amp;= \sum_{a \in \mathcal{A}(s)} \pi(a | s) \bigg[\sum_{r \in \mathcal{R}} p(r | s, a) r + \gamma \sum_{s^\prime \in \mathcal{S}} p(s^\prime | s, a) v_{\pi}(s^\prime) \bigg] \\
&amp;= r_\pi(s) + \gamma \sum_{s^\prime \in \mathcal{S}} p_\pi(s^\prime | s) v_{\pi}(s^\prime)
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;for all $s \in \mathcal{S}$. Let $n = |\mathcal{S}|$, we can write the equation in matrix form:&lt;/p&gt;
&lt;div&gt;
$$
\begin{bmatrix}
v_\pi(1) \\
\vdots \\
v_\pi(n)
\end{bmatrix}=
\begin{bmatrix}
r_\pi(1) \\
\vdots \\
r_\pi(n)
\end{bmatrix}
+\gamma
\begin{bmatrix}
p_\pi(1 | 1) &amp; \dots &amp; p_\pi(n | 1) \\
\vdots &amp; \ddots &amp; \vdots\\
p_\pi(1 | n) &amp; \dots &amp; p_\pi(n | n)
\end{bmatrix}
\begin{bmatrix}
v_\pi(1) \\
\vdots \\
v_\pi(n)
\end{bmatrix}
$$
&lt;/div&gt;
&lt;p&gt;More compactly:&lt;/p&gt;
&lt;p&gt;$$
v_\pi = r_\pi + \gamma P_\pi v_\pi
$$&lt;/p&gt;
&lt;p&gt;We can define an (expected) Bellman operator $\mathcal{T}^\pi : \mathbb{R}^n \to \mathbb{R}^n$ as:&lt;/p&gt;
&lt;p&gt;$$
\mathcal{T}^\pi(v) = r_\pi + \gamma P_\pi v
$$&lt;/p&gt;
&lt;p&gt;for any $v \in \mathbb{R}^n$.&lt;/p&gt;
&lt;p&gt;Let $||\cdot||$ be a norm in $\mathbb{R}^n$. If there exists a $\gamma \in (0, 1)$ such that $||\mathcal{T}^\pi(v_1) - \mathcal{T}^\pi(v_2)|| \leq \gamma ||v_1 - v_2||$ for all $v_1, v_2 \in \mathbb{R}^n$, then $\mathcal{T}^\pi$ is a contraction mapping. Note that the definition of a contraction uses a distance function $d$ for generality (&lt;a href=&#34;https://math.stackexchange.com/questions/172028/difference-between-norm-and-distance&#34;&gt;all norms can be used to create a distance function as in $d(x, y) = ||x - y||$, but not all distance functions have a corresponding norm&lt;/a&gt;). In all proofs in this post $|\cdot|$ and $\leq$ are elementwise, and the norm used is the max norm $||x||_\infty = \max(|x|) = \max(|x_1|, \dots, |x_n|)$, where $\max(\cdot) : \mathbb{R}^n \to \mathbb{R}$ chooses the largest element in a vector.&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
||\mathcal{T}^\pi(v_1) - \mathcal{T}^\pi(v_2)||_\infty &amp;= \max \big(|r_\pi + \gamma P_\pi v_1 - (r_\pi + \gamma P_\pi v_2)| \big) \\
&amp;= \gamma \max \big(|P_\pi(v_1 - v_2)| \big) \\
&amp;\leq \gamma \max \big(P_\pi|v_1 - v_2| \big) \\
&amp;\leq \gamma \max \big(|v_1 - v_2| \big) \\
&amp;= \gamma ||v_1 - v_2||_\infty
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;Thus $\mathcal{T}^\pi$ is a contraction. The last inequality is due to the rows of $P_\pi$ containing only non-negative elements that sum to 1.&lt;/p&gt;
&lt;p&gt;Let $r(s, a)$ denote the expected immediate reward when selecting action $a$ in state $s$, and $p_\pi(s^\prime, a^\prime | s, a)$ the probability of transitioning to state $s^\prime$ and selecting action $a^\prime$ when selecting action $a$ in state $s$ and following policy $\pi$ after. The Bellman equation for the action-value function for policy $\pi$ can be defined as follows:&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
q_{\pi}(s, a) &amp;= \mathbb{E}_\pi \big[R_{t+1} + \gamma q_{\pi}(S_{t+1}, A_{t+1}) | S_t = s, A_t = a \big] \\
&amp;=  \sum_{r \in \mathcal{R}} p(r | s, a) r + \gamma \sum_{s^\prime \in \mathcal{S}} p(s^\prime | s, a) \sum_{a^\prime \in \mathcal{A}(s^\prime)} \pi(a^\prime | s^\prime) q_{\pi}(s^\prime, a^\prime) \\
&amp;= r(s, a) + \gamma \sum_{s^\prime \in \mathcal{S}} \sum_{a^\prime \in \mathcal{A}(s^\prime)} p_\pi(s^\prime, a^\prime | s, a) q_{\pi}(s^\prime, a^\prime)
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;for all $s \in \mathcal{S}$, $a \in \mathcal{A}(s)$. Let $n = |\mathcal{S}|$ and $n_s = |\mathcal{A}(s)|$, we can write the equation in matrix form:&lt;/p&gt;
&lt;div&gt;
$$
\begin{bmatrix}
q_\pi(1, 1) \\
q_\pi(1, 2) \\
\vdots \\
q_\pi(n, n_n)
\end{bmatrix}=
\begin{bmatrix}
r_\pi(1, 1) \\
r_\pi(1, 2) \\
\vdots \\
r_\pi(n, n_n)
\end{bmatrix}
+\gamma
\begin{bmatrix}
p_\pi(1, 1 | 1, 1) &amp; p_\pi(1, 2 | 1, 1) &amp; \dots &amp; p_\pi(n, n_n | 1, 1) \\
p_\pi(1, 1 | 1, 2) &amp; p_\pi(1, 2 | 1, 2) &amp; \dots &amp; p_\pi(n, n_n | 1, 2) \\
\vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
p_\pi(1, 1 | n, n_n) &amp; p_\pi(1, 2 | n, n_n) &amp; \dots &amp; p_\pi(n, n_n | n, n_n)
\end{bmatrix}
\begin{bmatrix}
q_\pi(1, 1) \\
q_\pi(1, 2) \\
\vdots \\
q_\pi(n, n_n)
\end{bmatrix}
$$
&lt;/div&gt;
&lt;p&gt;More compactly:&lt;/p&gt;
&lt;p&gt;$$
q_\pi = r_\pi + \gamma P_\pi q_\pi
$$&lt;/p&gt;
&lt;p&gt;The only difference compared to the equation for the state-value function is that the vectors and matrices are larger. Thus, we can define an expected Bellman operator in identical fashion (with $n$ denoting the number of state-action pairs rather than the number of states) and the proof will be identical (the rows of $P_\pi$ still contain only non-negative elements that sum to 1).&lt;/p&gt;
&lt;h2 id=&#34;bellman-optimality-equations&#34;&gt;Bellman optimality equations&lt;/h2&gt;
&lt;p&gt;For brevity I will go directly into the compact matrix form of the Bellman optimality equation for the optimal state-value function:&lt;/p&gt;
&lt;p&gt;$$
v_* = \max_\pi(r_\pi + \gamma P_\pi v_*)
$$&lt;/p&gt;
&lt;p&gt;We can define an (expected) Bellman optimality operator $\mathcal{T}^* : \mathbb{R}^n \to \mathbb{R}^n$ as:&lt;/p&gt;
&lt;p&gt;$$
\mathcal{T}^*(v) = \max_\pi(r_\pi + \gamma P_\pi v)
$$&lt;/p&gt;
&lt;p&gt;Consider any two vectors $v_1, v_2 \in \mathbb{R}^n$, and let $\pi_1^* = \text{argmax}_\pi(r_\pi + \gamma P_\pi v_1)$ and $\pi_2^* = \text{argmax}_\pi(r_\pi + \gamma P_\pi v_2)$. Then we have:&lt;/p&gt;
&lt;div&gt;
$$
\mathcal{T}^*(v_1) = \max_\pi(r_\pi + \gamma P_\pi v_1) = r_{\pi_1^*} + \gamma P_{\pi_1^*} v_1 \geq r_{\pi_2^*} + \gamma P_{\pi_2^*} v_1 \\
\mathcal{T}^*(v_2) = \max_\pi(r_\pi + \gamma P_\pi v_2) = r_{\pi_2^*} + \gamma P_{\pi_2^*} v_2 \geq r_{\pi_1^*} + \gamma P_{\pi_1^*} v_2
$$
&lt;/div&gt;
&lt;p&gt;and&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
\mathcal{T}^*(v_1) - \mathcal{T}^*(v_2) &amp;= r_{\pi_1^*} + \gamma P_{\pi_1^*} v_1 - (r_{\pi_2^*} + \gamma P_{\pi_2^*} v_2) \\
&amp;\leq r_{\pi_1^*} + \gamma P_{\pi_1^*} v_1 - (r_{\pi_1^*} + \gamma P_{\pi_1^*} v_2) \\
&amp;= \gamma P_{\pi_1^*} (v_1 - v_2)
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;Similarly we have $\mathcal{T}^*(v_2) - \mathcal{T}^*(v_1) \leq \gamma P_{\pi_2^*} (v_2 - v_1)$, which implies that $\mathcal{T}^*(v_1) - \mathcal{T}^*(v_2) \geq \gamma P_{\pi_2^*} (v_1 - v_2)$, and thus:&lt;/p&gt;
&lt;div&gt;
$$
\gamma P_{\pi_2^*} (v_1 - v_2) \leq \mathcal{T}^*(v_1) - \mathcal{T}^*(v_2) \leq \gamma P_{\pi_1^*} (v_1 - v_2)
$$
&lt;/div&gt;
&lt;p&gt;Let $\max\{\cdot, \cdot \} : \mathbb{R}^n \times \mathbb{R}^n \to \mathbb{R}^n$ choose the largest values between two vectors elementwise, we have:&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
|\mathcal{T}^*(v_1) - \mathcal{T}^*(v_2)| &amp;\leq \max \{|\gamma P_{\pi_2^*} (v_1 - v_2)|, |\gamma P_{\pi_1^*} (v_1 - v_2)| \} \\
&amp;\leq \gamma \max \{P_{\pi_2^*} |v_1 - v_2|, P_{\pi_1^*} |v_1 - v_2| \}
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;Let $\max(\cdot) : \mathbb{R}^n \to \mathbb{R}$ choose the largest element in a vector (as previously defined), we have:&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
||\mathcal{T}^*(v_1) - \mathcal{T}^*(v_2)||_\infty &amp;= \max \big(|\mathcal{T}^*(v_1) - \mathcal{T}^*(v_2)| \big) \\
&amp;\leq \gamma \max \big(\max \{P_{\pi_2^*} |v_1 - v_2|, P_{\pi_1^*} |v_1 - v_2| \} \big) \\
&amp;\leq \gamma \max \big(|v_1 - v_2| \big) \\
&amp;= \gamma ||v_1 - v_2||_\infty
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;Thus $\mathcal{T}^*$ is a contraction. Note that the last inequality is due to the rows of $P_{\pi_1^*}$ and $P_{\pi_2^*}$ containing only non-negative elements that sum to 1. As for Bellman equations, an almost identical proof can be written for action-values with the only difference being that we consider state-action pairs rather than states.&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;We&amp;rsquo;re now ready to answer the questions from the introduction.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;How do we know that the Bellman equations and Bellman optimality equations have unique solutions?&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;We can define the right hand side of the equations as operators and show that they are contractions, which means that they have unique fixed-points that can be solved for by starting with an arbitrary point and iteratively applying the operator, due to the Banach fixed-point theorem. Any solution to one of the equations must be a fixed-point for its corresponding operator, thus these fixed-points are the unique solutions to the equations.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;Why can&amp;rsquo;t iterative policy evaluation or policy/value iteration get stuck in local optima?&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Iterative policy evaluation and value iteration are simply iteratively applying the Bellman operator or Bellman optimality operator respectively. As explained above we know that the operators have unique fixed-points that can be solved for in this manner. In other words, there are no local optima. However, note that convergence is only guaranteed in the limit and one might have to settle for approximations, depending on the complexity of the MDP. Policy iteration is not as straight forward, but it&amp;rsquo;s closely connected to value iteration and I find it quite intuitive that once we know that value iteration will converge to the optimal state-value/action-value function, then policy iteration will also do so. From &lt;a href=&#34;http://incompleteideas.net/book/the-book.html&#34;&gt;Reinforcement Learning: An Introduction by Sutton and Barto (2nd edition)&lt;/a&gt;:&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;Value iteration effectively combines, in each of its sweeps, one sweep of policy evaluation and one sweep of policy improvement. Faster convergence is often achieved by interposing multiple policy evaluation sweeps between each policy improvement sweep. In general, the entire class of truncated policy iteration algorithms can be thought of as sequences of sweeps, some of which use policy evaluation updates and some of which use value iteration updates. Because the max operation in (4.10) is the only difference between these updates this just means that the max operation is added to some sweeps of policy evaluation. All of these algorithms converge to an optimal policy for discounted finite MDPs.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Truncated policy iteration is the general method with value iteration at one end of the spectrum where only one iteration of policy evaluation is done before updating the policy greedily, and policy iteration at the other end where policy evaluation runs until convergence before updating the policy greedily. We know that value iteration only has one fixed-point and only have to consider what happens when we add more iterations of policy evaluation before doing one step of value iteration (applying the Bellman optimality operator). One way to think about it, is that we will simply update the policy using more accurate state-values or action-values, and it&amp;rsquo;s hard to see how this would interfere with convergence.&lt;/p&gt;
&lt;p&gt;Another way to think about it, is that each step will either improve or be as good as the previous policy due to the policy improvement theorem (see chapters 3 and 4 in &lt;a href=&#34;http://incompleteideas.net/book/the-book.html&#34;&gt;Reinforcement Learning: An Introduction by Sutton and Barto (2nd edition)&lt;/a&gt;). And since each iteration essentially ends with a step of value iteration, we know that the algorithm will only get stuck in one fixed-point - the same as in value iteration.&lt;/p&gt;
&lt;p&gt;Beyond intuitive explanations, there are of course formal proofs. From chapter 4 in &lt;a href=&#34;https://github.com/MathFoundationRL/Book-Mathmatical-Foundation-of-Reinforcement-Learning&#34;&gt;Mathematical Foundations of Reinforcement Learning by Zhao&lt;/a&gt; (the proof is in the book):&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;The idea of the proof is to show that policy iteration converge faster than value iteration. Since value iteration has been proven to be convergent, the convergence of policy iteration immediately follows.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;p&gt;Note that faster here does not mean faster compute time or fewer total number of iterations, since the policy evaluation step potentially requires an infinite number of iterations, but instead that policy iteration requires equally many or fewer policy updates before convergence.&lt;/p&gt;
&lt;p&gt;That&amp;rsquo;s it for this post. I hope you enjoyed the read and please feel free to reach out to me at &lt;a href=&#34;mailto:c@cfml.se&#34;&gt;c@cfml.se&lt;/a&gt; with any questions or feedback!&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Neural Style Transfer 2</title>
      <link>https://cfml.se/neural-style-transfer-2/</link>
      <pubDate>Sat, 08 Aug 2020 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/neural-style-transfer-2/</guid>
      <description>&lt;iframe
    src=&#34;https://cfml.se/notebooks/neural_style_transfer_2.html&#34;
    width=&#34;100%&#34;
    scrolling=&#34;no&#34;
    onload=&#34;this.style.height=(this.contentDocument.body.scrollHeight + 30) + &#39;px&#39;;&#34;
    style=&#34;border:none;&#34;&gt;
&lt;/iframe&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Linear Regression Library Comparison</title>
      <link>https://cfml.se/linear-regression-library-comparison/</link>
      <pubDate>Thu, 16 Jul 2020 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/linear-regression-library-comparison/</guid>
      <description>&lt;iframe
    src=&#34;https://cfml.se/notebooks/linear_regression_lib_comparison.html&#34;
    width=&#34;100%&#34;
    scrolling=&#34;no&#34;
    onload=&#34;this.style.height=(this.contentDocument.body.scrollHeight + 30) + &#39;px&#39;;&#34;
    style=&#34;border:none;&#34;&gt;
&lt;/iframe&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Image Classification using PyTorch</title>
      <link>https://cfml.se/image-classification-using-pytorch/</link>
      <pubDate>Sun, 29 Mar 2020 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/image-classification-using-pytorch/</guid>
      <description>&lt;iframe
    src=&#34;https://cfml.se/notebooks/image_classification_pytorch.html&#34;
    width=&#34;100%&#34;
    scrolling=&#34;no&#34;
    onload=&#34;this.style.height=(this.contentDocument.body.scrollHeight + 30) + &#39;px&#39;;&#34;
    style=&#34;border:none;&#34;&gt;
&lt;/iframe&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Deep Q-Networks</title>
      <link>https://cfml.se/deep-q-networks/</link>
      <pubDate>Thu, 02 Jan 2020 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/deep-q-networks/</guid>
      <description>&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;Deep Q-Networks (DQN) are artificial neural networks (ANN) that utilizes a variant of Q-learning in order to update the parameters of the network. DQN was introduced in &lt;a href=https://arxiv.org/abs/1312.5602&gt;Playing Atari with Deep Reinforcement Learning&lt;/a&gt; (Mnih et al., 2013).&lt;/p&gt;
&lt;p&gt;The aim of this article is to describe the core concepts of DQN and how it can be applied to the OpenAI Gym CartPole problem. The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/openai_gym_problems&#34;&gt;https://github.com/CarlFredriksson/openai_gym_problems&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;approximate-rl-methods&#34;&gt;Approximate RL Methods&lt;/h2&gt;
&lt;p&gt;In my previous post &lt;a href=&#34;https://cfml.se/temporal-difference-learning/&#34;&gt;Temporal Difference Learning&lt;/a&gt;, I described the basic tabular version of Q-learning. Tabular reinforcement learning (RL) methods are great for problems with a small state space, but are not suited for problems where the state space is large. This is because of the memory required to store the table of action-values and the time it would take to fill the table.&lt;/p&gt;
&lt;p&gt;Many RL methods can be extended to be approximate instead of tabular by using function approximators, such as linear models or ANNs. This makes them better suited for problems with large state spaces. Tabular methods define the action-value estimates $Q(s,a)$ as a function of state $s$ and action $a$. Approximate methods define them as $Q(s,a,w)$, where $w$ is the model parameters of the function approximator, such as the weights and biases of a neural network. Instead of updating the action-values themselves, approximate methods update the model parameters $w$. This is usually done using some form of gradient descent.&lt;/p&gt;
&lt;p&gt;Assume that we know the true $q(s,a)$ and want our estimate $Q(s,a,w)$ to move towards it. A natural loss function $\mathcal{L}(w)$ to facilitate this is the mean squared error (MSE)&lt;/p&gt;
&lt;p&gt;$$
\mathcal{L}(w) = \frac{1}{2} \big[q(s,a) - Q(s,a,w) \big]^2
$$&lt;/p&gt;
&lt;p&gt;Let $\alpha$ be the learning rate, now we can do gradient descent on $\mathcal{L}(w)$&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
    w_{t+1} &amp;= w_t - \alpha \nabla \mathcal{L}(w) \\
        &amp;= w_t - \frac{1}{2} \alpha \nabla \big[q(s,a) - Q(s,a,w) \big]^2 \\
        &amp;= w_t + \alpha \big[q(s,a) - Q(s,a,w) \big] \nabla Q(s,a,w)
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;Of course, if we knew the true $q(s,a)$, the problem would already be solved. Let us denote the target we want our action-value estimates to move towards with $Q_{target}$. Exchanging $Q_{target}$ with $q(s,a)$ in the update rule above gives us&lt;/p&gt;
&lt;div&gt;
$$
    w_{t+1} = w_t + \alpha \big[Q_{target} - Q(s,a,w) \big] \nabla Q(s,a,w)
$$
&lt;/div&gt;
&lt;p&gt;Let us now define $Q_{target}$ to be the target in Q-learning&lt;/p&gt;
&lt;div&gt;
$$
    Q_{target} = R + \gamma \operatorname{max}_a Q(S^\prime,a,w)
$$
&lt;/div&gt;
&lt;p&gt;This gives us&lt;/p&gt;
&lt;div&gt;
$$
    w_{t+1} = w_t + \alpha \big[R + \gamma \operatorname{max}_a Q(S^\prime,a,w) - Q(s,a,w) \big] \nabla Q(s,a,w)
$$
&lt;/div&gt;
&lt;p&gt;Which is the update rule for approximate Q-learning.&lt;/p&gt;
&lt;div class=&#34;pseudocode-container&#34;&gt;
    &lt;b&gt;Approximate version of Q-learning&lt;/b&gt;&lt;br&gt;
    &lt;div class=&#34;pseudocode&#34;&gt;
        Algorithm parameters: $\alpha \in (0,1]$, $\gamma \in (0,1]$, small $\epsilon &gt; 0$&lt;br&gt;
        Initialize model $Q(S,A,w)$ and model parameters $w \in \mathbb{R}^d$&lt;br&gt;
        Loop for each episode:
        &lt;div class=&#34;pseudocode-indent&#34;&gt;
            Initialize $S$&lt;br&gt;
            Loop until $S$ is terminal:
            &lt;div class=&#34;pseudocode-indent&#34;&gt;
                Choose $A \in \mathcal{A}(S)$ using a soft policy derived from $Q$ (e.g. $\epsilon$-greedy)&lt;br&gt;
                Take action $A$, observe $R$, $S^\prime$&lt;br&gt;
                If $S^\prime$ is terminal:
                &lt;div class=&#34;pseudocode-indent&#34;&gt;
                    $w \leftarrow w + \alpha \big[R - Q(S,A,w) \big] \nabla Q(S,A,w)$&lt;br&gt;
                &lt;/div&gt;
                Else:
                &lt;div class=&#34;pseudocode-indent&#34;&gt;
                    $w \leftarrow w + \alpha \big[R + \gamma \operatorname{max}_a Q(S^\prime,a,w) - Q(S,A,w) \big] \nabla Q(S,A,w)$&lt;br&gt;
                &lt;/div&gt;
                $S \leftarrow S^\prime$&lt;br&gt;
            &lt;/div&gt;
        &lt;/div&gt;
    &lt;/div&gt;
&lt;/div&gt;
&lt;h2 id=&#34;experience-replay&#34;&gt;Experience Replay&lt;/h2&gt;
&lt;p&gt;DQN uses the same update rule as the one described in approximate Q-learning above. The difference is that $S$, $A$, $R$, and $S^\prime$ are added to a memory instead of immediately used for parameter updating. The parameters are updated using randomly sampled batches of $S A R S^\prime$ experiences from memory. This is called experience replay.&lt;/p&gt;
&lt;p&gt;Experience replay has two major advantages. It makes more efficient use of previous experiences, since they are used for learning multiple times. It also has better convergence behavior when training a function approximator. One reason for this is that the experiences in a single trajectory are often correlated, but randomly sampling from the memory makes the data more like indepent and identically distributed (IID) data. This is important since most supervised learning convergence proofs assume IID data.&lt;/p&gt;
&lt;p&gt;One disadvantage with experience replay is that it is harder to use with multi-step learning algorithms, such as $Q(\lambda)$.&lt;/p&gt;
&lt;div class=&#34;pseudocode-container&#34;&gt;
    &lt;b&gt;DQN&lt;/b&gt;&lt;br&gt;
    &lt;div class=&#34;pseudocode&#34;&gt;
        Algorithm parameters: $\alpha, \gamma, \epsilon_{max}, \epsilon_{min}, \epsilon_{decay} \in (0,1]$, ${batch\_size} \geq 1$&lt;br&gt;
        Initialize ANN model $Q(S,A,w)$ and model parameters $w \in \mathbb{R}^d$&lt;br&gt;
        Initialize $\epsilon \leftarrow \epsilon_{max}$&lt;br&gt;
        Initialize memory&lt;br&gt;
        Loop for each episode:
        &lt;div class=&#34;pseudocode-indent&#34;&gt;
            Initialize $S$&lt;br&gt;
            Loop until $S$ is terminal:
            &lt;div class=&#34;pseudocode-indent&#34;&gt;
                Choose $A \in \mathcal{A}(S)$ using a soft policy derived from $Q$ (e.g. $\epsilon$-greedy)&lt;br&gt;
                Take action $A$, observe $R$, $S^\prime$&lt;br&gt;
                Add $(S,A,R,S^\prime)$ to memory&lt;br&gt;
                Sample random batch from memory&lt;br&gt;
                Loop for each $(S,A,R,S^\prime)$ in batch:
                &lt;div class=&#34;pseudocode-indent&#34;&gt;
                    If $S^\prime$ is terminal:
                    &lt;div class=&#34;pseudocode-indent&#34;&gt;
                        $w \leftarrow w + \alpha \big[R - Q(S,A,w) \big] \nabla Q(S,A,w)$&lt;br&gt;
                    &lt;/div&gt;
                    Else:
                    &lt;div class=&#34;pseudocode-indent&#34;&gt;
                        $w \leftarrow w + \alpha \big[R + \gamma \operatorname{max}_a Q(S^\prime,a,w) - Q(S,A,w) \big] \nabla Q(S,A,w)$&lt;br&gt;
                    &lt;/div&gt;
                &lt;/div&gt;
                $\epsilon \leftarrow \operatorname{max}(\epsilon_{min}, \epsilon \times \epsilon_{decay})$&lt;br&gt;
                $S \leftarrow S^\prime$&lt;br&gt;
            &lt;/div&gt;
        &lt;/div&gt;
    &lt;/div&gt;
&lt;/div&gt;
&lt;h2 id=&#34;openai-cartpole-problem&#34;&gt;OpenAI CartPole Problem&lt;/h2&gt;
&lt;p&gt;OpenAI Gym has many problems to try algorithms on. Let us apply DQN on the simple CartPole problem. Description from &lt;a href=&#34;https://gym.openai.com/envs/CartPole-v1/&#34;&gt;https://gym.openai.com/envs/CartPole-v1/&lt;/a&gt; :&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The system is controlled by applying a force of +1 or -1 to the cart. The pendulum starts upright, and the goal is to prevent it from falling over. A reward of +1 is provided for every timestep that the pole remains upright. The episode ends when the pole is more than 15 degrees from vertical, or the cart moves more than 2.4 units from the center.&lt;/p&gt;
&lt;/blockquote&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;random&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;gym&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;plt&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;collections&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;deque&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.models&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.layers&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.optimizers&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Adam&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;ENV_NAME&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;CartPole-v1&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;MEMORY_SIZE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1000&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;BATCH_SIZE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;20&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;GAMMA&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.95&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.001&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;EXPLORATION_MAX&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;1.0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;EXPLORATION_MIN&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.01&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;EXPLORATION_DECAY&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.995&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;DQNAgent&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_space_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action_space_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action_space_size&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action_space_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;exploration_rate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;EXPLORATION_MAX&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;memory&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;deque&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;maxlen&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;MEMORY_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;24&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;input_shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state_space_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;24&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action_space_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;linear&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;compile&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;mse&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Adam&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;lr&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;select_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Select action epsilon-greedily&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;rand&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;exploration_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;randrange&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action_space_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Q_values&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;predict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;argmax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;remember&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_next&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;done&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;memory&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_next&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;done&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;experience_replay&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;memory&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Update model parameters using random samples from memory&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;memory&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_next&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;done&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Q_update&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;done&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;GAMMA&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;amax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;predict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state_next&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Q_values&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;predict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Q_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q_update&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;fit&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;verbose&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;update_exploration_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;exploration_rate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;EXPLORATION_DECAY&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;exploration_rate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;max&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;EXPLORATION_MIN&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;exploration_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;cartpole&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create environment and agent&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;env&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gym&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;make&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ENV_NAME&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;state_space_size&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;env&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;observation_space&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;action_space_size&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;env&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action_space&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;dqn_agent&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;DQNAgent&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state_space_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action_space_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Simulate episodes&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;scores&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;episode&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;env&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reset&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;done&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;while&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;not&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;done&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;c1&#34;&gt;#env.render()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dqn_agent&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;select_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;state_next&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;done&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;info&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;env&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;scores&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;episode&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;state_next&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state_next&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;dqn_agent&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;remember&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_next&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;done&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;dqn_agent&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;experience_replay&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;dqn_agent&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;update_exploration_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_next&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Episode: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;, Score: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;format&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;episode&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;scores&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;episode&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Plot score&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;arange&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;scores&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xlabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Episode&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ylabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Score&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;show&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;vm&#34;&gt;__name__&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;cartpole&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/deep_q-networks/results.png&#34; alt=&#34;results&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;Thank you for reading and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>ANN without ML Framework</title>
      <link>https://cfml.se/ann-without-ml-framework/</link>
      <pubDate>Sun, 15 Dec 2019 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/ann-without-ml-framework/</guid>
      <description>&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;The aim of this article is to show how you can create an artificial neural network (ANN) without using a machine learning (ML) framework such as TensorFlow or PyTorch. This is not practical for most projects, but it&amp;rsquo;s a great exercise for building intuition about neural networks. The implemented network will be a simple feedforward neural network with a few customizable hyperparameters such as the number of hidden layers and their sizes. The network will be used to solve a simple regression problem.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/ann_without_ml_framework&#34;&gt;https://github.com/CarlFredriksson/ann_without_ml_framework&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;initialize-parameters&#34;&gt;Initialize Parameters&lt;/h2&gt;
&lt;p&gt;The only import we will need to create the network is&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Let&amp;rsquo;s start by creating a class for our network.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;SimpleNeuralNetwork&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Inside this class we are going to write some methods. Let&amp;rsquo;s start with the &lt;code&gt;__init__&lt;/code&gt; method that initializes the weights and biases of the network. The weights are initialized to normally distributed random numbers and the biases are initialized to zero.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;input_layer_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_layer_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;output_layer_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;b&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;hidden_layer_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add_layer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_layer_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;output_layer_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;elif&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;hidden_layer_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add_layer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_layer_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_layer_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;hidden_layer_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add_layer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;hidden_layer_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_layer_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add_layer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;hidden_layer_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;output_layer_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;add_layer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;previous_layer_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;layer_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;randn&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;previous_layer_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;layer_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.01&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;b&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;layer_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id=&#34;forward-propagation&#34;&gt;Forward Propagation&lt;/h2&gt;
&lt;p&gt;We have initialized the necessary member variables to write a forward propagation method. Let $m$ denote the batch size, $n$ the input layer size (layer 0), and $o$ the output layer size. Let $X \in \mathbb{R}^{m \times n}$ denote a batch of input vectors and $Y \in \mathbb{R}^{m \times o}$ their corresponding target output vectors. Let $W^{[l]} \in \mathbb{R}^{\text{size of }l-1 \times \text{size of }l}$ denote the weights, $b^{[l]} \in \mathbb{R}^{1 \times \text{size of }l}$ the biases, and $A^{[l]} \in \mathbb{R}^{m \times \text{size of }l}$ the activations of layer $l \in {1,2,\ldots,L}$ with layer $l = L$ being the output layer. All layers will use the &lt;em&gt;ReLU&lt;/em&gt; activation function except the output layer, which will be linear since we will be using the network on a regression problem. More formally&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
A^{[0]} &amp;amp;= X \\
Z^{[l]} &amp;amp;= A^{[l-1]} W^{[l]} + b^{[l]} \qquad \forall l \in {1,2,\ldots,L} \\
A^{[l]} &amp;amp;= max(0, Z) \qquad \forall l \in {1,2,\ldots,L-1} \\
A^{[L]} &amp;amp;= Z
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;Let $\mathcal{L}$ denote the loss. The loss function we will use is mean squared error (MSE). More formally&lt;/p&gt;
&lt;p&gt;$$
\mathcal{L} = \frac{1}{2 m} \sum_{i=1}^m \sum_{j=1}^{o} \big[Y_{i,j} - A^{[L]}_{i,j} \big]^2
$$&lt;/p&gt;
&lt;p&gt;We could divide by $o$ instead of 2 to make it more of a true mean, but dividing by 2 makes the gradient cleaner. Note that we store the computed $Z$ and $A$ values in order to use them for computing the gradient by backpropagation.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;propagate_forward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Z_cache&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;A_prev&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;A_prev&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;b&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;A&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# If not at the output layer, apply relu&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;A&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;maximum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Z_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;A_prev&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;/&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;*&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;**&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id=&#34;backpropagation&#34;&gt;Backpropagation&lt;/h2&gt;
&lt;p&gt;We will use backpropagation to efficiently compute the gradient of the loss function. Backpropagation might seem mysterious at first, but the idea is simply to use the chain rule of derivatives and the stored $Z$ and $A$ values for efficiency. We can work our way backwards&lt;/p&gt;
&lt;p&gt;$$
\begin{aligned}
\frac{\delta \mathcal{L}}{\delta A^{[L]}} &amp;amp;= \frac{1}{m} \sum A^{[L]} - Y \\
\frac{\delta \mathcal{L}}{\delta Z^{[L]}} &amp;amp;= \frac{\delta \mathcal{L}}{\delta A^{[L]}} \frac{\delta A^{[L]}}{\delta Z^{[L]}} = \frac{\delta \mathcal{L}}{\delta A^{[L]}} \qquad \text{(since the output layer is linear)} \\
\frac{\delta \mathcal{L}}{\delta W^{[L]}} &amp;amp;= \frac{\delta \mathcal{L}}{\delta Z^{[L]}} \frac{\delta Z^{[L]}}{\delta W^{[L]}} = \frac{\delta \mathcal{L}}{\delta Z^{[L]}} A^{[L-1]} \\
\frac{\delta \mathcal{L}}{\delta b^{[L]}} &amp;amp;= \frac{\delta \mathcal{L}}{\delta Z^{[L]}} \frac{\delta Z^{[L]}}{\delta b^{[L]}} = \frac{\delta \mathcal{L}}{\delta Z^{[L]}}  \\
\frac{\delta \mathcal{L}}{\delta A^{[L-1]}} &amp;amp;= \frac{\delta \mathcal{L}}{\delta Z^{[L]}} \frac{\delta Z^{[L]}}{\delta A^{[L-1]}} = \frac{\delta \mathcal{L}}{\delta Z^{[L]}} W^{[L-1]} \\
\frac{\delta \mathcal{L}}{\delta Z^{[L-1]}} &amp;amp;= \frac{\delta \mathcal{L}}{\delta A^{[L-1]}} \frac{\delta A^{[L-1]}}{\delta Z^{[L-1]}} = \frac{\delta \mathcal{L}}{\delta A^{[L-1]}} \frac{\delta max(0,Z^{[L-1]})}{\delta Z^{[L-1]}} \\
\frac{\delta \mathcal{L}}{\delta W^{[L-1]}} &amp;amp;= \frac{\delta \mathcal{L}}{\delta Z^{[L-1]}} \frac{\delta Z^{[L-1]}}{\delta W^{[L-1]}} = \frac{\delta \mathcal{L}}{\delta Z^{[L-1]}} A^{[L-2]} \\
\frac{\delta \mathcal{L}}{\delta b^{[L-1]}} &amp;amp;= \frac{\delta \mathcal{L}}{\delta Z^{[L-1]}} \frac{\delta Z^{[L-1]}}{\delta b^{[L-1]}} = \frac{\delta \mathcal{L}}{\delta Z^{[L-1]}}  \\
&amp;amp;\vdots
\end{aligned}
$$&lt;/p&gt;
&lt;p&gt;A pattern emerges and the computation of partial derivatives can be implemented iteratively. Some of the matrix multiplications above might not have matching dimensions but I find it more productive to focus on those details when implementing, since we probably want to change the order of some computations to make it more efficient anyway (for example by not averaging over examples immediately). This is not the only possible implementation and I encourage you to explore alternatives.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;propagate_backward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;dW&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;db&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;b&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# dA is short for dL/dA etc.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;dA&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;dZ&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dA&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;reversed&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;A&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;dW&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;/&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;A&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dZ&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;db&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;/&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dZ&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;dA&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dZ&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;dZ&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;multiply&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dA&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;relu_derivative&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;dW&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;/&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dZ&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;db&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;/&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dZ&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dW&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;db&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;relu_derivative&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id=&#34;training-step&#34;&gt;Training Step&lt;/h2&gt;
&lt;p&gt;During training we will alternate forward propagation and backpropagation. Let&amp;rsquo;s create a method that combines them for convenience.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;propagate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;propagate_forward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;dW&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;db&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;propagate_backward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dW&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;db&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;A single training step consists of one combined propagation and one gradient descent step.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;predictions&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dW&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;db&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;neural_network&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;propagate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;neural_network&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;neural_network&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dW&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;neural_network&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;b&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;db&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id=&#34;regression-problem&#34;&gt;Regression Problem&lt;/h2&gt;
&lt;p&gt;Let&amp;rsquo;s generate some quadratic data with random noise added.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generate_random_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;linspace&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;200&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;**&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;noise&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;normal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;noise&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;We can now use our &lt;code&gt;SimpleNeuralNetwork&lt;/code&gt; class to create model that can fit to the generated data.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;vm&#34;&gt;__name__&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;generate_random_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;output/data.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;neural_network&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;SimpleNeuralNetwork&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.01&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_iterations&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10000&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_iterations&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;predictions&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dW&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;db&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;neural_network&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;propagate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;neural_network&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;neural_network&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dW&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;neural_network&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;b&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;db&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;%&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1000&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;loss:&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;neural_network&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;propagate_forward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot_results&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;A_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;output/results.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/ann_without_ml_framework/results.png&#34; alt=&#34;results&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;I hope you found this article helpful. Thank you for reading and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Temporal Difference Learning</title>
      <link>https://cfml.se/temporal-difference-learning/</link>
      <pubDate>Sun, 10 Nov 2019 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/temporal-difference-learning/</guid>
      <description>&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;&amp;ldquo;If one had to identify one idea as central and novel to reinforcement learning, it would
undoubtedly be &lt;em&gt;temporal-difference&lt;/em&gt; (TD) learning&amp;rdquo; (Reinforcement Learning: An Introduction, Sutton, Barto, 2018).&lt;/p&gt;
&lt;p&gt;TD learning combines ideas from &lt;a href=&#34;https://cfml.se/monte-carlo-methods/&#34;&gt;Monte Carlo Methods&lt;/a&gt; (MC methods) and &lt;a href=&#34;https://cfml.se/dynamic-programming/&#34;&gt;Dynamic Programming&lt;/a&gt; (DP). Like DP and MC methods, TD methods are a form of generalized policy iteration (GPI), which means that they alternate policy evaluation (estimation of value functions) and policy improvement (using value estimates to improve a policy). Like MC methods, TD methods don&amp;rsquo;t need to know the underlying dynamics of the model and instead learn from experience. Unlike MC methods but like DP, TD methods use bootstrapping, which means that value estimates are updated using other value estimates.&lt;/p&gt;
&lt;p&gt;Unlike MC methods, which update value estimates and the target policy after each episode, TD methods update estimates and the target policy after each time step. This makes TD methods work on continuous tasks as well as episodic tasks. On episodic tasks, where MC methods can work as well, TD methods usually converge faster.&lt;/p&gt;
&lt;p&gt;The aim of this article is to be a short introduction to TD learning. To read more about this subject, I recommend (Sutton, Barto, 2018). Later in the article there is going to be an implemented example of a couple of TD methods that solves a gridworld problem. The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/temporal_difference_learning&#34;&gt;https://github.com/CarlFredriksson/temporal_difference_learning&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;sarsa&#34;&gt;Sarsa&lt;/h2&gt;
&lt;p&gt;Sarsa is an on-policy TD method that uses action-value estimates $Q_\pi(s,a)$ to improve the target policy $\pi$. To assure that the agent explores sufficiently, the policy $\pi$ has to be soft, which means that all actions in every state has a probability greater than zero of being selected. A commonly used type of soft policies are $\epsilon$-greedy policies that select the action that maximizes $Q_\pi(s,a)$ with probability $1-\epsilon$ and a random action with probability $\epsilon$.&lt;/p&gt;
&lt;p&gt;In Sarsa, the action-value estimates are improved using the update rule&lt;/p&gt;
&lt;div&gt;
$$
    Q(S_t,A_t) \leftarrow Q(S_t,A_t) + \alpha \big[R_{t+1} + \gamma Q(S_{t+1},A_{t+1}) - Q(S_t,A_t) \big]
$$
&lt;/div&gt;
&lt;p&gt;The name Sarsa comes from the fact that $S_t,A_t,R_{t+1},S_{t+1},A_{t+1}$ are used in the update rule, which spells out SARSA if the subscripts are dropped.&lt;/p&gt;
&lt;p&gt;Pseudocode for Sarsa with the target policy being any soft policy derived from $Q$ (e.g. $\epsilon$-greedy) is given below.&lt;/p&gt;
&lt;div class=&#34;pseudocode-container&#34;&gt;
    &lt;b&gt;Sarsa&lt;/b&gt;&lt;br&gt;
    &lt;div class=&#34;pseudocode&#34;&gt;
        Algorithm parameters: $\alpha \in (0,1]$, $\gamma \in (0,1]$, small $\epsilon &gt; 0$&lt;br&gt;
        Initialize $Q(s,a)$, for all $s \in \mathcal{S}^+$, $a \in \mathcal{A}(s)$, arbitrarily except that $Q(terminal,\dot{}) \leftarrow 0$&lt;br&gt;
        Loop for each episode:
        &lt;div class=&#34;pseudocode-indent&#34;&gt;
            Initialize $S$&lt;br&gt;
            Choose $A \in \mathcal{A}(S)$ using a soft policy derived from $Q$ (e.g. $\epsilon$-greedy)&lt;br&gt;
            Loop until $S$ is terminal:
            &lt;div class=&#34;pseudocode-indent&#34;&gt;
                Take action $A$, observe $R$, $S^\prime$&lt;br&gt;
                Choose $A^\prime \in \mathcal{A}(S^\prime)$ using a soft policy derived from $Q$ (e.g. $\epsilon$-greedy)&lt;br&gt;
                $Q(S,A) \leftarrow Q(S,A) + \alpha \big[R + \gamma Q(S^\prime,A^\prime) - Q(S,A) \big]$&lt;br&gt;
                $S \leftarrow S^\prime$&lt;br&gt;
                $A \leftarrow A^\prime$
            &lt;/div&gt;
        &lt;/div&gt;
    &lt;/div&gt;
&lt;/div&gt;
&lt;h2 id=&#34;q-learning&#34;&gt;Q-learning&lt;/h2&gt;
&lt;p&gt;Q-learning is an off-policy TD method. Thus, it uses two different policies, the the target policy $\pi$ and the behavior policy $b$. The target policy is used for value estimation and is the target for improvement, while the behavior policy is used for action selection. The behavior policy has to be soft in order to maintain sufficient exploration. The target policy can be completely greedy.&lt;/p&gt;
&lt;p&gt;In Q-learning, the action-value estimates are improved using the update rule&lt;/p&gt;
&lt;div&gt;
$$
    Q(S_t,A_t) \leftarrow Q(S_t,A_t) + \alpha \big[R_{t+1} + \gamma \underset{a}{\operatorname{max}} Q(S_{t+1},a) - Q(S_t,A_t) \big]
$$
&lt;/div&gt;
&lt;p&gt;The name Q-learning simply comes from the use of action value estimates $Q$.&lt;/p&gt;
&lt;p&gt;Pseudocode for Q-learning with the target policy being the greedy policy derived from $Q$ and the behavior policy being any soft policy derived from $Q$ (e.g. $\epsilon$-greedy) is given below.&lt;/p&gt;
&lt;div class=&#34;pseudocode-container&#34;&gt;
    &lt;b&gt;Q-learning&lt;/b&gt;&lt;br&gt;
    &lt;div class=&#34;pseudocode&#34;&gt;
        Algorithm parameters: $\alpha \in (0,1]$, $\gamma \in (0,1]$, small $\epsilon &gt; 0$&lt;br&gt;
        Initialize $Q(s,a)$, for all $s \in \mathcal{S}^+$, $a \in \mathcal{A}(s)$, arbitrarily except that $Q(terminal,\dot{}) \leftarrow 0$&lt;br&gt;
        Loop for each episode:
        &lt;div class=&#34;pseudocode-indent&#34;&gt;
            Initialize $S$&lt;br&gt;
            Loop until $S$ is terminal:
            &lt;div class=&#34;pseudocode-indent&#34;&gt;
                Choose $A \in \mathcal{A}(S)$ using a soft policy derived from $Q$ (e.g. $\epsilon$-greedy)&lt;br&gt;
                Take action $A$, observe $R$, $S^\prime$&lt;br&gt;
                $Q(S,A) \leftarrow Q(S,A) + \alpha \big[R + \gamma \operatorname{max}_a Q(S^\prime,a) - Q(S,A) \big]$&lt;br&gt;
                $S \leftarrow S^\prime$&lt;br&gt;
            &lt;/div&gt;
        &lt;/div&gt;
    &lt;/div&gt;
&lt;/div&gt;
&lt;h2 id=&#34;windy-gridworld-example&#34;&gt;Windy Gridworld Example&lt;/h2&gt;
&lt;p&gt;We are going to use Example 6.5: Windy Gridworld (Sutton, Barto, 2018) as a problem for showcasing Python implementations of Sarsa and Q-learning. It is a standard gridworld problem with the addition of wind that moves the agent upward a different number of cells depending on which column the agent is on after taking an action. The state is the cell position of the agent and the available actions are UP, DOWN, LEFT, RIGHT in every state. The actions deterministically move the agent in their respective directions before wind is applied. If the agent would move outside of the gridworld (by an action or by the wind), it stays in same cell as before moving. The task is episodic and undiscounted with a reward of -1 given after every action, except when the goal state G is reached. The gridworld is visualized in the following image, with the blue line representing the optimal trajectory when starting from state S.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/temporal_difference_learning/windy_gridworld.png&#34; alt=&#34;windy_gridworld&#34;&gt;&lt;/p&gt;
&lt;p&gt;Note that MC methods cannot easily be used for this problem since not all policies are guaranteed to end up in a terminal state (the goal state). This can lead to infinite episodes if the number of steps isn&amp;rsquo;t limited. TD methods don&amp;rsquo;t have the same problem since they learn during episodes and will eventually end up in the goal state. Below is an implementation of the gridworld environment agents can interact with by calling the &lt;code&gt;take_action&lt;/code&gt; function.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;Action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Enum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;UP&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;DOWN&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;LEFT&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;RIGHT&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;Environment&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;starting_state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;wind&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;goal&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;7&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;starting_state&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;adjust_position&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# X position&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;elif&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;9&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;9&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Y Position&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;elif&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;6&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;6&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;take_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Move agent from action&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Action&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;UP&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;elif&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Action&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;DOWN&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;elif&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Action&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LEFT&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;elif&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Action&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;RIGHT&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;ERROR: INVALID ACTION SELECTED!&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;adjust_position&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Move agent from wind&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;!=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;goal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;wind&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;adjust_position&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;goal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Implementations of a Sarsa agent and a Q-learning agent:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;SarsaAgent&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;7&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Take action&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;s&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;r&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;take_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;s_prime&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Select next action and update Q&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;a_prime&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon_greedy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;s_prime&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;r&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;s_prime&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;s_prime&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a_prime&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a_prime&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Check if goal has been reached&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;r&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;QLearningAgent&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;7&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Select action&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;s&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;a&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon_greedy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Take action&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;r&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;take_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;s_prime&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Update Q&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;r&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;max&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;s_prime&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;s_prime&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:])&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Check if goal has been reached&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;r&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;epsilon_greedy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;a&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;argmax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;rand&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;a&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;randint&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Some code for running the agents and plotting the results:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;vm&#34;&gt;__name__&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;200&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;MAX_NUM_STEPS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1000&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ALPHA&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;EPSILON&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;STARTING_STATE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;agents&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;SarsaAgent&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ALPHA&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;EPSILON&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;QLearningAgent&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ALPHA&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;EPSILON&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;agents&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Agents learning&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;agent_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;agent&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;enumerate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;agents&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;episode&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Environment&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;STARTING_STATE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;agent_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;episode&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;MAX_NUM_STEPS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;n&#34;&gt;at_goal&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;agent&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;at_goal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;k&#34;&gt;break&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Plot results&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;label&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Sarsa&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;label&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Q-learning&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;legend&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xlabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Episode&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ylabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Number of steps to reach goal&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;show&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/temporal_difference_learning/results.png&#34; alt=&#34;results&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;We have described two of the most popular TD methods (and reinforcement learning methods in general), the on-policy method Sarsa and the off-policy method Q-learning. They both are one-step, model-free, and tabular TD methods. TD methods can be extended to &lt;em&gt;n&lt;/em&gt;-step forms, which link them to MC methods, and to forms that include a model of the environment, which link them to DP. TD methods can also be extended to be non-tabular and instead use various forms of function approximation, for example by using artificial neural networks.&lt;/p&gt;
&lt;p&gt;Thank you for reading and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Monte Carlo Methods</title>
      <link>https://cfml.se/monte-carlo-methods/</link>
      <pubDate>Sat, 02 Nov 2019 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/monte-carlo-methods/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/monte_carlo_methods/final_policy.png&#34; alt=&#34;final_policy&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;Monte Carlo (MC) methods are a broad class of algorithms that uses repeated random sampling to compute numerical results. The aim of this article is to be a short introduction to MC methods in reinforcement learning (RL). To read more about this subject, I recommend Reinforcement Learning: An Introduction (Sutton, Barto, 2018).&lt;/p&gt;
&lt;p&gt;MC methods have three main advantages over &lt;a href=&#34;https://cfml.se/dynamic-programming/&#34;&gt;Dynamic Programming&lt;/a&gt; (DP):&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;They can be used to learn optimal policies by interacting with the actual environment, without having any knowledge of the underlying &lt;a href=&#34;https://cfml.se/markov-decision-processes/&#34;&gt;Markov Decision Process&lt;/a&gt; (MDP).&lt;/li&gt;
&lt;li&gt;They can be used with simulated environments that only need to generate sample transitions. Generating sample transitions is often much easier than completely specifying the dynamics function $p$ of the MDP model.&lt;/li&gt;
&lt;li&gt;They can be focused on a small subset of states. States of special interest can be evaluated without the computational complexity of accurately evaluating all states.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Like DP, MC methods are a form of generalized policy iteration (GPI), which means that they alternate policy evaluation (estimation of value functions) and policy improvement (using value estimates to improve a policy). Unlike DP, MC methods do not use bootstrapping. In other words, value estimates are not computed using other value estimates.&lt;/p&gt;
&lt;p&gt;We will be primarily interested in estimating the action-value function $q_\pi(s,a)$ since we want to be able to pick the actions that maximize expected return. We denote the estimates by $Q_\pi(s,a)$ and compute them by averaging sample returns in each state-action pair similarily to contextual &lt;a href=&#34;https://cfml.se/multi-armed-bandits/&#34;&gt;Multi-armed Bandits&lt;/a&gt;. The difference is that the underlying model is an MDP, which means that the states are interrelated.&lt;/p&gt;
&lt;p&gt;This article is going to focus on MC methods for episodic tasks. That is, tasks where the agent eventually reach a terminal state. After each episode is finished, the action-value estimates and the policy is updated. The action-value estimate of a state-action pair is updated such that it&amp;rsquo;s the return following the first visit, or every visit to the pair, averaged over all episodes. A visit to a state action pair $(s,a)$ means that the agent was in state $s$ and took the action $a$. Both the first-visit method and the every-visit method converge to the actual action values as the number of visits to each state-action pair approaches infinity. Like DP, we don&amp;rsquo;t need perfect estimates and can truncate the value estimation before the policy is improved. Usually we will update the policy for a state immediately after an action-value has been updated in that state.&lt;/p&gt;
&lt;p&gt;Later in the article there is going to be an implemented example of a Monte Carlo method that solves a simple version of blackjack. The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/monte_carlo_methods&#34;&gt;https://github.com/CarlFredriksson/monte_carlo_methods&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;on-policy-methods&#34;&gt;On-policy Methods&lt;/h2&gt;
&lt;p&gt;The objective of all RL methods is to find policies that maximize expected return. To achieve this, agents have to balance exploitation of current knowledge (selecting actions that maximizes $Q_\pi$) with exploration of other actions that have lower action-value estimates in order to learn better estimates. This is called the exploration versus exploitation tradeoff.&lt;/p&gt;
&lt;p&gt;There are two approaches for ensuring exploration, &lt;em&gt;on-policy&lt;/em&gt; and &lt;em&gt;off-policy&lt;/em&gt; methods. On-policy methods use a single policy $\pi$ that is used both for selecting actions and for evaluation and improvement. Off-policy methods have two policies, one that is evaluated and improved and one that is used for action selection. In this section we are going to look at on-policy methods.&lt;/p&gt;
&lt;h3 id=&#34;exploring-starts&#34;&gt;Exploring Starts&lt;/h3&gt;
&lt;p&gt;To be able to estimate the action-value of a state-action pair, the pair needs to be visited. On way to make sure that the agent visits all state-action pairs is to do &lt;em&gt;exploring starts&lt;/em&gt;. This is done by starting the episodes in a random state and selecting a random starting action such that every state-action pair has a probability greater than zero of being selected. Exploring starts is useful for some problems but cannot be relied upon in general, particularly when learning from interaction with an actual environment.&lt;/p&gt;
&lt;p&gt;Pseudocode for MC with exploring starts:&lt;/p&gt;
&lt;div class=&#34;pseudocode-container&#34;&gt;
    &lt;b&gt;On-policy MC with exploring starts&lt;/b&gt;&lt;br&gt;
    &lt;div class=&#34;pseudocode&#34;&gt;
        Initialize:
        &lt;div class=&#34;pseudocode-indent&#34;&gt;
            $\pi(s) \in \mathcal{A}(s)$ arbitrarily for all $s \in \mathcal{S}$&lt;br&gt;
            $Q(s,a) \in \mathbb{R}$ arbitrarily for all $s \in \mathcal{S}$, $a \in \mathcal{A}(s)$&lt;br&gt;
            $N(s,a) \leftarrow 0$ for all $s \in \mathcal{S}$, $a \in \mathcal{A}(s)$
        &lt;/div&gt;
        Loop for multiple episodes:
        &lt;div class=&#34;pseudocode-indent&#34;&gt;
            Select $(S_0,A_0)$ randomly such that all pairs have probability &gt; 0&lt;br&gt;
            Generate an episode from $(S_0,A_0)$, following $\pi: S_0,A_0,R_1,\ldots,S_{T-1},A_{T-1},R_T$&lt;br&gt;
            $G \leftarrow 0$&lt;br&gt;
            Loop for each step of the episode, $t=T-1,T-2,\ldots,0$:
            &lt;div class=&#34;pseudocode-indent&#34;&gt;
                $G \leftarrow \gamma G + R_{t+1}$&lt;br&gt;
                If $(S_t,A_t) \notin \{(S_0,A_0),(S_1,A_1),\ldots,(S_{t-1},A_{t-1})\}$:
                &lt;div class=&#34;pseudocode-indent&#34;&gt;
                    $N(S_t,A_t) \leftarrow N(S_t,A_t) + 1$&lt;br&gt;
                    $Q(S_t,A_t) \leftarrow Q(S_t,A_t) + \frac{1}{N(S_t,A_t)} [G - Q(S_t,A_t)]$&lt;br&gt;
                    $\pi(S_t) \leftarrow argmax_a Q(S_t,a)$ (ties broken arbitrarily)
                &lt;/div&gt;
            &lt;/div&gt;
        &lt;/div&gt;
    &lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;If you want discounting, set $0 &amp;lt; \gamma &amp;lt; 1$. If you don&amp;rsquo;t want discounting, which is often the case for episodic tasks, set $\gamma = 1$. The algorithm above uses the first-visit method. If you want to use the every-visit method instead, simply remove the if-statement that checks if a state-action pair has been visited in an earlier time step.&lt;/p&gt;
&lt;h3 id=&#34;soft-policies&#34;&gt;Soft Policies&lt;/h3&gt;
&lt;p&gt;A common way to make sure that the agent explores is to use a soft policy. Soft policies satisfy $\pi(a|s) &amp;gt; 0$ for all $s \in \mathcal{S}$ and all $a \in \mathcal{A}(s)$. The probabilities are often shifted closer and closer to a deterministic policy.&lt;/p&gt;
&lt;p&gt;$\epsilon$-greedy policies are soft policies that select the action that maximizes $Q_\pi$ with probability $1 - \epsilon$, and a random action with probability $\epsilon$.&lt;/p&gt;
&lt;p&gt;Pseudocode for on-policy MC with $\epsilon$-greedy policy:&lt;/p&gt;
&lt;div class=&#34;pseudocode-container&#34;&gt;
    &lt;b&gt;On-policy MC with $\epsilon$-greedy policy&lt;/b&gt;&lt;br&gt;
    &lt;div class=&#34;pseudocode&#34;&gt;
        Initialize:
        &lt;div class=&#34;pseudocode-indent&#34;&gt;
            $\pi(a|s)$ to an arbitrary soft policy&lt;br&gt;
            $Q(s,a) \in \mathbb{R}$ arbitrarily for all $s \in \mathcal{S}$, $a \in \mathcal{A}(s)$&lt;br&gt;
            $N(s,a) \leftarrow 0$ for all $s \in \mathcal{S}$, $a \in \mathcal{A}(s)$
        &lt;/div&gt;
        Loop for multiple episodes:
        &lt;div class=&#34;pseudocode-indent&#34;&gt;
            Generate an episode following $\pi: S_0,A_0,R_1,\ldots,S_{T-1},A_{T-1},R_T$&lt;br&gt;
            $G \leftarrow 0$&lt;br&gt;
            Loop for each step of the episode, $t=T-1,T-2,\ldots,0$:
            &lt;div class=&#34;pseudocode-indent&#34;&gt;
                $G \leftarrow \gamma G + R_{t+1}$&lt;br&gt;
                If $(S_t,A_t) \notin \{(S_0,A_0),(S_1,A_1),\ldots,(S_{t-1},A_{t-1})\}$:
                &lt;div class=&#34;pseudocode-indent&#34;&gt;
                    $N(S_t,A_t) \leftarrow N(S_t,A_t) + 1$&lt;br&gt;
                    $Q(S_t,A_t) \leftarrow Q(S_t,A_t) + \frac{1}{N(S_t,A_t)} [G - Q(S_t,A_t)]$&lt;br&gt;
                    $A^* \leftarrow argmax_a Q(S_t,a)$ (ties broken arbitrarily)&lt;br&gt;
                    $\pi(A^*|S_t) \leftarrow 1 - \epsilon + \epsilon / |\mathcal{A}(S_t)|$&lt;br&gt;
                    For all $a \neq A^* \in \mathcal{A}(S_t)$:
                    &lt;div class=&#34;pseudocode-indent&#34;&gt;
                        $\pi(a|S_t) \leftarrow \epsilon / |\mathcal{A}(S_t)|$
                    &lt;/div&gt;
                &lt;/div&gt;
            &lt;/div&gt;
        &lt;/div&gt;
    &lt;/div&gt;
&lt;/div&gt;
&lt;h2 id=&#34;off-policy-methods&#34;&gt;Off-policy Methods&lt;/h2&gt;
&lt;p&gt;Off-policy methods have two policies, the &lt;em&gt;target policy&lt;/em&gt; $\pi$ and the &lt;em&gt;behavior policy&lt;/em&gt; $b$. The policy that is evaluated and improved is $\pi$ while $b$ is used for action selection. $\pi$ is often a greedy deterministic policy while $b$ has to be a soft policy in order for the agent to explore sufficiently.&lt;/p&gt;
&lt;h3 id=&#34;importance-sampling&#34;&gt;Importance Sampling&lt;/h3&gt;
&lt;p&gt;If action-value estimates are updated in the same way as on-policy methods while following $b$, they would be estimates of $q_b$ rather than the $q_\pi$ we are interested in. To correctly estimate $q_\pi$, we need to use &lt;em&gt;importance sampling&lt;/em&gt;. When using importance sampling, returns are weighted by the relative probability of their trajectories occuring under the target and behavior policies. This relative probability is called the &lt;em&gt;importance-sampling ratio&lt;/em&gt;.&lt;/p&gt;
&lt;p&gt;Given a starting state-action pair $S_t,A_t$, the probability of the subsequent state-action trajectory, $S_{t+1},A_{t+1},S_{t+2},A_{t+2},\ldots,S_T$, occuring while following any policy $\pi$ is&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
    Pr&amp;\{S_{t+1},A_{t+1},S_{t+2},A_{t+2},\ldots,S_T | S_t,A_t,A_{t+1:T-1} \sim \pi \} \\
    &amp;= p(S_{t+1} | S_t,A_t) \pi(A_{t+1} | S_{t+1}) p(S_{t+2} | S_{t+1},A_{t+1}) \pi(A_{t+2} | S_{t+2}) \ldots p(S_T | S_{T-1},A_{T-1}) \\
    &amp;= p(S_{t+1} | S_t,A_t) \prod_{k=t+1}^{T-1} \pi(A_k | S_k) p(S_{k+1} | S_k,A_k)
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;where&lt;/p&gt;
&lt;div&gt;
$$
    p(S_{t+1} | S_t,A_t) = \sum_r p(S_{t+1},R_{t+1} = r | S_t,A_t)
$$
&lt;/div&gt;
&lt;p&gt;Thus, the importance-sampling ratio is&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
    \rho_{t:T-1} &amp;= \frac{p(S_{t+1} | S_t,A_t) \prod_{k=t+1}^{T-1} \pi(A_k | S_k) p(S_{k+1} | S_k,A_k)}{p(S_{t+1} | S_t,A_t) \prod_{k=t+1}^{T-1} b(A_k | S_k) p(S_{k+1} | S_k,A_k)} \\
    &amp;= \frac{\prod_{k=t+1}^{T-1} \pi(A_k | S_k)}{\prod_{k=t+1}^{T-1} b(A_k | S_k)}
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;The state-transition probabilities cancel and we are left with a definition that doesn&amp;rsquo;t require any knowledge about the MDP dynamics.&lt;/p&gt;
&lt;p&gt;There are two types of importance sampling, &lt;em&gt;ordinary importance sampling&lt;/em&gt; that uses a simple average of the weighted returns and &lt;em&gt;weighted importance sampling&lt;/em&gt; that uses a weighted average. Let the time step $t$ span over episodes in such a way that if the first episode is 100 steps long, the next episode will start at $t=101$. Let $\tau(s,a)$ be the set of time steps where the state-action pair $(s,a)$ were first visited if using the first-visit method and the set of time steps for all visits if using the every-visit method. Let $T(t)$ be the first termination time step after $t$. We can then estimate $q_\pi$ using ordinary importance sampling by&lt;/p&gt;
&lt;div&gt;
$$
    Q(s,a) = \frac{\sum_{t \in \tau(s,a)} \rho_{t:T(t)-1} G_t}{|\tau(s,a)|}
$$
&lt;/div&gt;
&lt;p&gt;and using weighted importance sampling by&lt;/p&gt;
&lt;div&gt;
$$
    Q(s,a) = \frac{\sum_{t \in \tau(s,a)} \rho_{t:T(t)-1} G_t}{\sum_{t \in \tau(s,a)} \rho_{t:T(t)-1}}
$$
&lt;/div&gt;
&lt;p&gt;Ordinary importance sampling produces unbiased estimates, but has larger, possibly infinite variance. Weighted importance sampling produces biased estimates but has finite variance, which often makes it preffered in practice.&lt;/p&gt;
&lt;p&gt;Pseudocode for off-policy MC with weighted importance sampling:&lt;/p&gt;
&lt;div class=&#34;pseudocode-container&#34;&gt;
    &lt;b&gt;Off-policy MC with weighted importance sampling&lt;/b&gt;&lt;br&gt;
    &lt;div class=&#34;pseudocode&#34;&gt;
        Initialize:
        &lt;div class=&#34;pseudocode-indent&#34;&gt;
            $Q(s,a) \in \mathbb{R}$ arbitrarily for all $s \in \mathcal{S}$, $a \in \mathcal{A}(s)$&lt;br&gt;
            $C(s,a) \leftarrow 0$ for all $s \in \mathcal{S}$, $a \in \mathcal{A}(s)$&lt;br&gt;
            $\pi(s) \leftarrow argmax_a Q(s,a)$ (ties broken consistently)&lt;br&gt;
        &lt;/div&gt;
        Loop for multiple episodes:
        &lt;div class=&#34;pseudocode-indent&#34;&gt;
            $b \leftarrow$ any soft policy&lt;br&gt;
            Generate an episode following $b: S_0,A_0,R_1,\ldots,S_{T-1},A_{T-1},R_T$&lt;br&gt;
            $G \leftarrow 0$&lt;br&gt;
            $W \leftarrow 1$&lt;br&gt;
            Loop for each step of the episode, $t=T-1,T-2,\ldots,0$:
            &lt;div class=&#34;pseudocode-indent&#34;&gt;
                $G \leftarrow \gamma G + R_{t+1}$&lt;br&gt;
                $C(S_t,A_t) \leftarrow C(S_t,A_t) + W$&lt;br&gt;
                $Q(S_t,A_t) \leftarrow Q(S_t,A_t) + \frac{W}{C(S_t,A_t)} [G - Q(S_t,A_t)]$&lt;br&gt;
                $\pi(s) \leftarrow argmax_a Q(S_t,a)$ (ties broken consistently)&lt;br&gt;
                If $A_t \neq \pi(S_t)$: exit inner loop (proceed to next episode)&lt;br&gt;
                $W \leftarrow W \frac{1}{b(A_t|S_t)}$
            &lt;/div&gt;
        &lt;/div&gt;
    &lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;The above algorithm uses the every-visit method. Note that $W$ and $C$ are used to implement weighted importance sampling incrementally.&lt;/p&gt;
&lt;h2 id=&#34;blackjack-example&#34;&gt;Blackjack Example&lt;/h2&gt;
&lt;p&gt;The objective of blackjack is to obtain cards whose numerical sum is as big as possible without exceeding 21. All face cards count as 10 and aces count as 1 or 11. The game begins with two cards dealt to both dealer and player. The dealer has one of his cards face down. If the player has 21 immediately he wins, unless the dealer also has 21, in which case the game is a draw. If the player doesn&amp;rsquo;t have 21, he can choose between the actions &lt;em&gt;hit&lt;/em&gt; and &lt;em&gt;stick&lt;/em&gt;. If he chooses hit he draws a new card and if the sum is greater than 21 he goes bust and immediately loses. If it isn&amp;rsquo;t, he can once again choose between hit and stick. If he chooses stick, the action goes over to the dealer who plays the fixed strategy of sticking on any sum 17 or greater and hitting otherwise. If the dealer goes bust, the player wins. Otherwise, win, lose, or draw is determined by whose final sum is greater. This simplified version of blackjack has no splitting and cards are assumed to be picked from an infinite number or decks (or picked with replacement).&lt;/p&gt;
&lt;p&gt;Playing blackjack is naturally formulated as a finite MDP where each game is an episode. Rewards are 0 everwhere except when reaching terminal states, where +1 is given for a win, 0 for a draw, and -1 for a loss. No discounting is done, thus the terminal rewards are also the returns. The state is a combination of the player sum, if the player has a usable ace, and the dealer&amp;rsquo;s showing card. If the player sum is 11 or less it is always best to hit, thus the values for player sum we are interested in is 12-21. The player has a usable ace if he can count an ace as 11 without going bust. The dealer shows one card (ace-10). Thus there are 200 possible states. The available actions are hit or stick in every non-terminal state.&lt;/p&gt;
&lt;p&gt;Note how much easier it is to generate sample transitions than to fully define the dynamics function $p$. Below is an implementation of an MC agent that interacts with the blackjack environment to learn an optimal policy. The agent uses on-policy exploring starts with the every visit method. The final policy, which is probably optimal or close to optimal, is plotted after training.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;enum&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Enum&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;copy&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;copy&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;mpl&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;plt&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;Action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Enum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;HIT&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;STICK&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;State&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;usable_ace&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dealer_showing&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;usable_ace&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;usable_ace&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dealer_showing&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dealer_showing&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;to_string&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;player_sum: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;, usable_ace: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;, dealer_showing: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;format&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;usable_ace&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dealer_showing&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;Environment&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;starting_state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;starting_state&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;draw_card&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;cards&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;arange&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;probabilities&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ones&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;9&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;/&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;13&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;/&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;13&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;choice&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;cards&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;p&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;probabilities&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;take_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Action&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;HIT&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;c1&#34;&gt;# Get new card&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;card&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;draw_card&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;card&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;c1&#34;&gt;# Ace worth 1 or 11&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;11&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;usable_ace&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;card&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;c1&#34;&gt;# Check if player busted or needs to use an ace as 1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;21&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;usable_ace&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;usable_ace&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;c1&#34;&gt;# Player busted&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Player sticks, dealer now acts&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;dealer_usable_ace&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dealer_showing&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;dealer_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;11&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dealer_usable_ace&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dealer_showing&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;while&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dealer_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;17&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;c1&#34;&gt;# Dealer get new card&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;card&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;draw_card&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;card&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;c1&#34;&gt;# Ace worth 1 or 11&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dealer_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;n&#34;&gt;dealer_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;11&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;n&#34;&gt;dealer_usable_ace&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;n&#34;&gt;dealer_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;n&#34;&gt;dealer_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;card&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;c1&#34;&gt;# Check if dealer busted or needs to use an ace as 1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dealer_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;21&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dealer_usable_ace&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;n&#34;&gt;dealer_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;n&#34;&gt;dealer_usable_ace&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;c1&#34;&gt;# Dealer busted&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Round is over&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dealer_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;elif&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dealer_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;Agent&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# The policy is a deterministic mapping from state to action: 0 = HIT, 1 = STICK&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Init policy to only stick on 20 and 21&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;8&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;action_values&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_visits&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;run_episode&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;starting_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;history&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Generate episode&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;copy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;starting_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;take_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;starting_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;while&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;is&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;usable_ace&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;12&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dealer_showing&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;copy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;take_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Use episode history to update state-values&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;usable_ace&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;player_sum&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;12&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dealer_showing&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;l&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;value&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_visits&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;l&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;l&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;l&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_visits&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;l&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;vm&#34;&gt;__name__&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10000000&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;agent&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Agent&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Agent learning policy&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;episode&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPISODES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;episode&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;%&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10000&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;STARTING EPISODE:&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;episode&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;random_starting_state&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;State&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;randint&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;12&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;high&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;22&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;randint&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;randint&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;high&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;11&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Environment&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random_starting_state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;random_starting_action&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;randint&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;agent&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run_episode&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;environment&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;random_starting_action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Plot final policy&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ticks&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;arange&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;xtick_labels&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;A&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;2&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;3&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;4&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;5&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;6&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;7&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;8&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;9&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;10&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ytick_labels&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;12&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;22&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;fig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ax1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;ax2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplots&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imshow&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;agent&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mpl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;colors&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ListedColormap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;red&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;blue&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;origin&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;lower&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_xticks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ticks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_xticklabels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ticks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_xlabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Dealer showing&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_yticks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ticks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_yticklabels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ytick_labels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_ylabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Player sum&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_title&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Usable ace&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;im2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;ax2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imshow&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;agent&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mpl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;colors&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ListedColormap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;red&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;blue&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;origin&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;lower&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_xticks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ticks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_xticklabels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ticks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_yticks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ticks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_yticklabels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ytick_labels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ax2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_title&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;No usable ace&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;fig&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplots_adjust&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;right&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.85&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;cbar_ax&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fig&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add_axes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.9&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.4&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.02&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;cbar&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fig&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;colorbar&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;im2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cax&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;cbar_ax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;ticks&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;cbar&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ax&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_yticklabels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Hit&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Stick&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;final_policy.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/monte_carlo_methods/final_policy.png&#34; alt=&#34;final_policy&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;This has been a short introduction to Monte Carlo methods in reinforcement learning. On-policy and off-policy methods were discussed and a simple on-policy MC agent was implemented. The blackjack example is a good starting point for implementing MC methods because of its simplicity.&lt;/p&gt;
&lt;p&gt;Thank you for reading and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Dynamic Programming</title>
      <link>https://cfml.se/dynamic-programming/</link>
      <pubDate>Mon, 14 Oct 2019 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/dynamic-programming/</guid>
      <description>&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;Dynamic programming (DP) refers to a collection of methods in both mathematical optimization and computer programming. What all methods have in common is that they simplify a complicated problem by breaking it down into simpler sub-problems in a recursive manner. The focus of this article is going to be the role of DP in mathematical optimization and for solving finite &lt;a href=&#34;https://cfml.se/markov-decision-processes/&#34;&gt;Markov Decision Processes&lt;/a&gt; (MDPs) in particular. Later in the article there is going to be a segment about DP in computer programming.&lt;/p&gt;
&lt;p&gt;Bellman equations are the key components of DP in mathematical optimization. To read more about them and MDPs in general, I refer you to my previous article &lt;a href=&#34;https://cfml.se/markov-decision-processes/&#34;&gt;Markov Decision Processes&lt;/a&gt;. As a reminder, the Bellman equations for the state-value function $v_{\pi}$ and action-value function $q_{\pi}$ are:&lt;/p&gt;
&lt;div&gt;
$$
    v_{\pi}(s) = \sum_a \pi(a|s) \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma v_{\pi}(s^\prime) \big)
$$
&lt;/div&gt;
&lt;div&gt;
$$
    q_{\pi}(s,a) = \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma \sum_{a^\prime} \pi(a^\prime|s^\prime) q(a^\prime,s^\prime) \big)
$$
&lt;/div&gt;
&lt;p&gt;The Bellman optimality equations are the Bellman equations for the optimal state-value function $v_&lt;em&gt;$ and optimal action-value function $q_&lt;/em&gt;$:&lt;/p&gt;
&lt;div&gt;
$$
    v_*(s) = \max_a \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma v_*(s^\prime) \big)
$$
&lt;/div&gt;
&lt;div&gt;
$$
    q_*(s,a) = \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma \max_{a^\prime} q_*(s^\prime,a^\prime) \big)
$$
&lt;/div&gt;
&lt;p&gt;The equations are going to be used as update rules in iterative algorithms. Remember that it&amp;rsquo;s trivial to find an optimal policy when you know $v_&lt;em&gt;$ or $q_&lt;/em&gt;$.&lt;/p&gt;
&lt;p&gt;Note that DP algorithms have limited utility in reinforcement learning (RL) since they assume perfect knowledge of the dynamics of the MDP and are very computationally expensive. However DP is still useful for many other problems and provides a good foundation for understanding many RL-algorithms.&lt;/p&gt;
&lt;p&gt;The aim of this article is to be a short introduction to DP with a focus on solving MDPs. To read more about this subject, I recommend Reinforcement Learning: An Introduction (Sutton, Barto, 2018). The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/dynamic_programming&#34;&gt;https://github.com/CarlFredriksson/dynamic_programming&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;policy-evaluation&#34;&gt;Policy Evaluation&lt;/h2&gt;
&lt;p&gt;Policy evaluation is about computing estimates of value functions, for example the state-value function $v_{\pi}$ for a given policy $\pi$. To compute the exact solution, one could theoretically solve the system of linear Bellman equations. However, the computational complexity grows quickly with the size of the MDP and iterative methods are often more efficient. We can use the Bellman equation for $v_{\pi}$ as an update rule&lt;/p&gt;
&lt;div&gt;
$$
    v_{k+1}(s) = \sum_a \pi(a|s) \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma v_{k}(s^\prime) \big)
$$
&lt;/div&gt;
&lt;p&gt;Applied iteratively, this update rule will make $v_{k+1}$ become closer to $v_{\pi}$ and eventually converge as $k \to \infty$. We of course don&amp;rsquo;t have infinite time and instead will have to accept approximations and stop when the updates become satisfactorily small. The updates can be done by remembering all the old values $v_{k}$ or by sweeping over the state space, updating values in place. If sweeping, the updates will use a mix of new and old values. This is not strictly the update rule as defined above, but $v_{k+1}$ will still converge to $v_{\pi}$. Sweeping uses less memory and usually converges faster so it&amp;rsquo;s normally the version that is used in DP.&lt;/p&gt;
&lt;p&gt;The idea of updating estimates using other estimates is called bootstrapping and is used extensively in RL.&lt;/p&gt;
&lt;h2 id=&#34;policy-improvement&#34;&gt;Policy Improvement&lt;/h2&gt;
&lt;p&gt;To solve an MDP, our goal is to find an optimal policy. The next step after approximating $v_{\pi}$ using policy evaluation is to improve the current policy $\pi$. To create a new improved policy $\pi^\prime$, we go over each state and select actions such that expected return is maximized assuming $\pi$ being followed afterwards. In other words, we select actions greedily with respect to $q_\pi$:&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
    \pi^\prime(s) &amp;= \underset{a}{\operatorname{argmax}} q_\pi(s,a) \\
        &amp;= \underset{a}{\operatorname{argmax}} E[R_{t+1} + \gamma v_\pi(S_{t+1}) | S_t=s, A_t=a] \\
        &amp;= \underset{a}{\operatorname{argmax}} \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma v_\pi(s^\prime) \big)
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;When selecting actions this way we know that the new policy $\pi^\prime$ is going to be as good as, or better than the old policy $\pi$. If $\pi^\prime$ is as good as, but not better than $\pi$ it means that $\pi$ and $\pi^\prime$ are both optimal. Like in policy evaluation, policy improvement is usually implemented as a sweep.&lt;/p&gt;
&lt;h2 id=&#34;policy-iteration&#34;&gt;Policy Iteration&lt;/h2&gt;
&lt;p&gt;Now that we have tools for evaluating and improving policies we simply have to put them together. Policy iteration refers to alternating policy evaluation and policy improvement for multiple iterations until a satisfactory policy is found. Unless an optimal policy has been found, each iteration will produce a strictly better policy. Policy iteration will converge to an optimal policy in finite time since finite MDPs have a finite number of possible deterministic policies. To speed up convergence, policy evaluation will start with the value estimates of the previous policy.&lt;/p&gt;
&lt;h2 id=&#34;value-iteration&#34;&gt;Value Iteration&lt;/h2&gt;
&lt;p&gt;There is a drawback to policy iteration in that the policy evaluation step can be computationally expensive. When doing policy evaluation iteratively, $v_{\pi}$ converges only in the limit and we will have to accept approximations. But the ultimate goal is not to compute $v_{\pi}$ as accurately as possible each policy iteration, it is to find an optimal policy. How close to the real state-value function do we actually need to get? For the gridworld in example 4.1 from (Sutton, Barto, 2018), only three iterations of policy evaluation on the random policy was needed before the result of the corresponding policy improvement wouldn&amp;rsquo;t change. This can be seen in figure 4.1 from (Sutton, Barto, 2018).&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/dynamic_programming/rl_an_introduction_fig_4_1.png&#34; alt=&#34;rl_an_introduction_fig_4_1&#34;&gt;&lt;/p&gt;
&lt;p&gt;This was only an example, but it holds in general that the policy evaluation step can be truncated without removing the convergence guarantees of policy iteration. An important special case is when policy evaluation is stopped after a single iteration. This is called value iteration. Instead of having two steps it can be combined into a single update rule&lt;/p&gt;
&lt;div&gt;
$$
    v_{k+1}(s) = \max_a \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma v_k(s^\prime) \big)
$$
&lt;/div&gt;
&lt;p&gt;This is almost the same update rule as policy evaluation, but using the Bellman optimality equation for $v_*$ instead. Value iteration is usually implemented as a sweep, with every sweep effectively being a combination of one policy evaluation sweep and one policy iteration sweep. Usually doing a few policy evaluation sweeps for every policy improvement sweep will lead to faster convergence. All policy iteration algorithms with a truncated policy evaluation step can be implemented as a sequence of sweeps, with some sweeps using policy evaluation updates and some using value iteration updates.&lt;/p&gt;
&lt;h2 id=&#34;generalized-policy-iteration&#34;&gt;Generalized Policy Iteration&lt;/h2&gt;
&lt;p&gt;Policy iteration consists of policy evaluation and policy improvement steps that are alternated. As mentioned, the number of iterations of policy evaluation before doing policy improvement can be varied. There are other methods of finding optimal, or approximately optimal, policies that similarly alternate policy evaluation and improvement but vary other aspects. One such aspect is the granularity of the updates, which means that some states could get multiple updates before others gets a single one. This is often the case for RL-methods where the agent doesn&amp;rsquo;t have complete knowledge of the dynamics function $p$ and instead rely on exploration and sample returns.&lt;/p&gt;
&lt;p&gt;The term &lt;em&gt;generalized policy iteration&lt;/em&gt; (GPI) refers the general idea of alternating policy evaluation and policy improvement. It is an important idea since almost all RL methods are a form of GPI.&lt;/p&gt;
&lt;h2 id=&#34;gridworld-example&#34;&gt;Gridworld Example&lt;/h2&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/dynamic_programming/rl_an_introduction_example_4_1.png&#34; alt=&#34;rl_an_introduction_example_4_1&#34;&gt;&lt;/p&gt;
&lt;p&gt;We will use example 4.1 from (Sutton, Barto, 2018). States $1,2,&amp;hellip;,14$ are non-terminal and state $15$ is terminal. State $15$ is in both the top-left and the bottom-right. Actions &amp;ldquo;up&amp;rdquo;, &amp;ldquo;down&amp;rdquo;, &amp;ldquo;right&amp;rdquo;, and &amp;ldquo;left&amp;rdquo; are available in every state. Actions deterministically move the agent in their respective direction except when the action would move the agent out of the gridworld. In that case the state remains unchanged. Actions in all non-terminal states give the agent a reward of -1. It is an undiscounted episodic task.&lt;/p&gt;
&lt;p&gt;Since the dynamics are completely deterministic we can implement them as a lookup table $(s,a) \to (s^\prime,r)$. The actions $(0,1,2,3)$ correspond to &amp;ldquo;up&amp;rdquo;, &amp;ldquo;down&amp;rdquo;, &amp;ldquo;right&amp;rdquo;, and &amp;ldquo;left&amp;rdquo; respectively. The state numbers have all been reduced by one to make them start at zero which makes the policy evaluation and value iteration functions cleaner.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;DYNAMICS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;14&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;o&#34;&gt;...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;14&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;14&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;14&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;14&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;14&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;14&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;14&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;14&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;p&#34;&gt;}&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;A policy evaluation sweep can be implemented as follows:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;policy_evaluation_sweep&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dynamics&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_states&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_states&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;value&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;num_actions&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_actions&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;new_state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dynamics&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;value&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;new_state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;state_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;value&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;I also implemented a non-sweeping version for comparison with figure 4.1 from (Sutton, Barto, 2018), since they used a non-sweeping version when producing the results in that figure.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;policy_evaluation&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dynamics&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_states&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;new_state_values&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_states&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_states&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;value&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;num_actions&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_actions&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;new_state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dynamics&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;value&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;new_state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;new_state_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;value&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;new_state_values&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Policy improvement can be implemented as follows:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;value_iteration_sweep&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dynamics&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_states&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_states&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;max_value&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;inf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;num_actions&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_actions&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;new_state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dynamics&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;value&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reward&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;state_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;new_state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;value&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;max_value&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;n&#34;&gt;max_value&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;value&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;n&#34;&gt;policy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;state&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;eye&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_actions&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;action&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The policy was initialized to the random policy. When running some iterations of non-sweeping policy evaluation the results matched those in figure 4.1. Policy improvement applied to the state-values after at least three iterations of non-sweeping policy evaluation resulted in an optimal policy. The sweeping policy evaluation needed a few more iterations before policy improvement would find an optimal policy.&lt;/p&gt;
&lt;h2 id=&#34;dp-in-computer-programming&#34;&gt;DP in Computer Programming&lt;/h2&gt;
&lt;p&gt;As mentioned in the introduction, DP refers to a collection of methods in both mathematical optimization and computer programming. For DP to be applicable to a programming problem, the problem needs to have optimal substructure and overlapping sub-problems.&lt;/p&gt;
&lt;p&gt;Optimal substructure means that the solution to the problem can be constructed from optimal solutions to its sub-problems. Optimal substructures are usually described by recursion. An example of this is shortest path problems. If there is a path $A \to B \to C \to D$ that is the shortest path from $A$ to $D$, then $B \to C \to D$ is the shortest path from $B$ to $D$. In other words, the shortest path problem from $B$ to $D$ is nested inside the $A$ to $D$ problem and can be called a sub-problem.&lt;/p&gt;
&lt;p&gt;Overlapping sub-problems means that a simple recursive algorithm will solve the same sub-problems over and over. If we can solve a problem by combining optimal solutions to non-overlapping sub-problems, DP is not applicable to that problem and the strategy divide and conquer might be used instead. Examples of divide and conquer algorithms are merge sort and quick sort.&lt;/p&gt;
&lt;p&gt;DP has two approaches for avoiding having to solve the same sub-problems over and over. These are the top-down approach and the bottom-up approach. The top-down approach is to do recursion in the order the problem is formulated, but also store the result of solved sub-problems so that they don&amp;rsquo;t need to be computed more than once. Storing solutions to sub-problems is called memoization. The bottom-up approach is to solve sub-problems from the bottom first and using those solutions to build solutions to bigger sub-problems.&lt;/p&gt;
&lt;h3 id=&#34;fibonacci-sequence-example&#34;&gt;Fibonacci Sequence Example&lt;/h3&gt;
&lt;p&gt;Computing the &lt;em&gt;n&lt;/em&gt;th member of the Fibonacci sequence is a good example of a problem that has both optimal substructure and overlapping sub-problems. The Fibonacci sequence $F_n$ is defined as:&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
    F_0 &amp;= 0 \\
    F_1 &amp;= 1 \\
    F_n &amp;= F_{n-1} + F_{n-2} 
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;The following is a naive implementation:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;fib&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fib&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fib&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The naive implementation is computationally very inefficient since it computes the solutions to the same sub-problems over and over. For example, if we call &lt;code&gt;fib(5)&lt;/code&gt; we produce the following call tree:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;fib(5)&lt;/li&gt;
&lt;li&gt;fib(4) + fib(3)&lt;/li&gt;
&lt;li&gt;(fib(3) + fib(2)) + (fib(2) + fib(1))&lt;/li&gt;
&lt;li&gt;((fib(2) + fib(1)) + (fib(1) + fib(0))) + ((fib(1) + fib(0)) + fib(1))&lt;/li&gt;
&lt;li&gt;(((fib(1) + fib(0)) + fib(1)) + (fib(1) + fib(0))) + ((fib(1) + fib(0)) + fib(1))&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;You can see that &lt;code&gt;fib(2)&lt;/code&gt; was computed from scratch three times. This implementation runs in exponential time (O(2^n)).&lt;/p&gt;
&lt;p&gt;A DP implementation with the top-down approach:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;memory&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;{&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;fib_memoization&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;not&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;memory&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;memory&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fib_memoization&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fib_memoization&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;memory&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;This version uses memoization and computes solutions to sub-problems only if they aren&amp;rsquo;t in the memory yet. This implementation runs in linear time (O(&lt;em&gt;n&lt;/em&gt;)) but requires O(&lt;em&gt;n&lt;/em&gt;) storage space.&lt;/p&gt;
&lt;p&gt;A DP implementation with the bottom-up approach:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;fib_bottom_up&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;previous_fib&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;current_fib&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;new_fib&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;previous_fib&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;current_fib&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;previous_fib&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;current_fib&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;current_fib&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;new_fib&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;current_fib&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;This version computes sub-problems from the bottom and builds upwards. The smaller values are computed first and are used to compute larger values. This implementation also runs in linear time but only requires constant (O(1)) storage space.&lt;/p&gt;
&lt;p&gt;Both DP implementations actually take O(n^2) time for large integers, since addition with large integers is O(n).&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;This has been a short introduction to dynamic programming. Splitting up a complex problem (for example computing an optimal policy in an MDP) into simpler, interdependent sub-problems (for example computing state-values) is a powerful idea that can be applied to a variety of problems. Understanding DP and the idea of GPI provides a good base for understanding many reinforcement learning algorithms.&lt;/p&gt;
&lt;p&gt;Thank you for reading and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Markov Decision Processes</title>
      <link>https://cfml.se/markov-decision-processes/</link>
      <pubDate>Tue, 08 Oct 2019 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/markov-decision-processes/</guid>
      <description>&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;A Markov decision process (MDP) is a mathematical formalization of a sequential decision-making process where actions affect both immediate reward and the next state. MDPs are similar to &lt;a href=&#34;https://cfml.se/multi-armed-bandits/&#34;&gt;Multi-armed Bandits&lt;/a&gt; in that the agent repeatedly has to make decisions and receives immediate rewards depending on what action is selected. What makes MDPs more complex is that the actions affect the subsequent states, and thus affect future rewards as well. The goal of the agent is the same, to maximize expected return (expected cumulative future reward). In an MDP, each state could have different actions available. After an agent selects an action, the environment gives the agent a reward and presents a new state. Figure 3.1 from Reinforcement Learning: An Introduction (Sutton, Barto, 2018) illustrates the interaction between agent and environment.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/markov_decision_processes/rl_an_introduction_fig_3_1.png&#34; alt=&#34;rl_an_introduction_fig_3_1&#34;&gt;&lt;/p&gt;
&lt;p&gt;At each time step $t=0,1,2,&amp;hellip;$, the agent observes a representation of the environment&amp;rsquo;s state $S_t \in \mathcal{S}$, where $\mathcal{S}$ is the set of all states that aren&amp;rsquo;t terminal states (the set $\mathcal{S}^+$ includes all terminal states). The agent selects an action $A_t \in \mathcal{A}(S_t)$, where $\mathcal{A}(S_t)$ is the set of all available actions in state $S_t$. The next time step, the agent receives a numerical reward $R_{t+1} \in \mathcal{R} \subset \mathbb{R}$ that depends on the previous state and selected action. The agent observes a new state $S_{t+1}$ and the process continues. This process creates a sequence (or trajectory) of the following form:&lt;/p&gt;
&lt;p&gt;$$
S_0,A_0,R_1,S_1,A_1,R_2,S_2,A_2,R_3,&amp;hellip;
$$&lt;/p&gt;
&lt;h3 id=&#34;dynamics-function-p&#34;&gt;Dynamics Function $p$&lt;/h3&gt;
&lt;p&gt;This article is going to focus on finite MDPs, which are MDPs where the sets $\mathcal{S}, \mathcal{A}, \mathcal{R}$ all have a finite number of elements. Finite MDPs are defined by their dynamics function $p$:&lt;/p&gt;
&lt;p&gt;$$
p(s^\prime,r|s,a) = Pr{S_t=s^\prime, R_t=r | S_{t-1}=s, A_t=a}
$$&lt;/p&gt;
&lt;p&gt;In other words, $p$ defines the joint probability of observing state $s^\prime$ and receiving reward $r$ after being in state $s$ and taking action $a$.&lt;/p&gt;
&lt;h3 id=&#34;episodic-and-continuous-tasks&#34;&gt;Episodic and Continuous Tasks&lt;/h3&gt;
&lt;p&gt;MDPs can be divided into episodic and continuous tasks. Episodic tasks have a clear ending, the agent will eventually reach a terminal state. Continuous tasks have no clear ending and potentially continue forever. In both type of tasks the goal is to maximize expected return. However, the return can be defined in different ways.&lt;/p&gt;
&lt;p&gt;The return is denoted $G_t$ and the final time step is denoted $T$. In episodic tasks, $T$ is finite and $G_t$ can be defined as:&lt;/p&gt;
&lt;p&gt;$$
G_t = R_{t+1} + R_{t+2} + R_{t+3} + &amp;hellip; + R_{T} = \sum_{k=t+1}^{T} R_k
$$&lt;/p&gt;
&lt;p&gt;In continuous tasks, $T = \infty$ and $G_t$ could easily diverge. To make $G_t$ converge we add add a parameter $0 \leq \gamma \leq 1$, called the discount rate:&lt;/p&gt;
&lt;p&gt;$$
G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + &amp;hellip; = \sum_{k=0}^{\infty} \gamma^k R_{t+k+1}
$$&lt;/p&gt;
&lt;p&gt;This is called discounting and means that rewards closer to the current time step will be prioritized. We can also define the return in a way that works for both episodic and continuous tasks:&lt;/p&gt;
&lt;p&gt;$$
G_t = \sum_{k=t+1}^{T} \gamma^{k-t-1} R_k
$$&lt;/p&gt;
&lt;p&gt;When the task is episodic we usually set $\gamma = 1$ (sometimes it might make sense to use discounting for an episodic task). When the task is continuous we usually set $\gamma &amp;lt; 1$.&lt;/p&gt;
&lt;h3 id=&#34;examples&#34;&gt;Examples&lt;/h3&gt;
&lt;p&gt;A wide array of problems can be formulated as MDPs. Some examples:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Chess&lt;/strong&gt;: The states are the position of the pieces. Rewards are 0 except when reaching a terminal state (win, draw, loss). A reward of 1 is given for a win, 1/2 for a draw, and 0 for a loss. The actions are all legal moves in a given state and they deterministically change the state. Since there are terminal states that eventually will be reached it is an episodic task.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Cleaning robot&lt;/strong&gt;: The states are combinations of the robots sensor readings, such as position and dust-detection sensors. Positive rewards are given when dust is cleaned up. You might also add a negative reward when the robot runs out of battery, to incentivize the robot to go to its recharging station before running out. The actions could be on the abstraction level of activating motors, or higher level like &amp;ldquo;move right&amp;rdquo;. Could be modelled either as an episodic or continuous task, depending on if you want the robot to run all the time or not.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Temperature control&lt;/strong&gt;: The states are temperature readings on a thermometer. Negative reward is given when a human manually has to adjust the temperature. The actions are how much power to send to different radiators. Will probably be a continuous task.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Note the importance of designing good rewards, such that if the agent learns to optimize expected return, it will also perform the task very well.&lt;/p&gt;
&lt;h3 id=&#34;solving-mdps&#34;&gt;Solving MDPs&lt;/h3&gt;
&lt;p&gt;The objective of the agent is to maximize expected cumulative future reward. In other words, we want to find a policy that gives the agent as much expected return as possible. A policy is a mapping from states to probabilities of selecting each available action in that state. Solving an MDP usually means finding an optimal policy. When MDPs are small (the sets of states and actions are small) and the dynamics function $p$ is fully known, we can find the optimal policy by solving the Bellman equations, more on this later. MDPs are often too big for this to be computationally feasible. There is a collection of algorithms called dynamic programming (DP) that can be more efficient at finding the optimal policy. When the MDPs are too large for DP or when the dynamics function is not explicitly known (which is often the case for real world appplications), we use reinforcement learning (RL) instead. In RL the aim is usually to find a policy that is good but doesn&amp;rsquo;t have to be optimal.&lt;/p&gt;
&lt;p&gt;The aim of this article is to introduce the basics of finite MDPs. To read more about this subject, I recommend the aforementioned (Sutton, Barto, 2018).&lt;/p&gt;
&lt;h2 id=&#34;policies-and-value-functions&#34;&gt;Policies and Value Functions&lt;/h2&gt;
&lt;p&gt;As mentioned above, a policy is a mapping from states to probabilities of selecting each available action in that state. More formally, if the agent is following policy $\pi$ at time $t$, then $\pi(a|s)$ is the probability that $A_t=a$ if $S_t=s$. A special case is deterministic policies where we denote the action that will be selected in state $s$ by $\pi(s)$.&lt;/p&gt;
&lt;p&gt;Value functions are functions that map states, or state-action pairs, to expected return when following a specific policy $\pi$. You can think of value functions as how good a state, or state-action pair, is when following $\pi$. Value functions often can&amp;rsquo;t be computed exactly, but estimating them and using the estimates to make policy decisions is a part of almost all RL algorithms. There are two types of value functions, state-value functions and action-value functions.&lt;/p&gt;
&lt;p&gt;The state-value function of a state $s$ when following policy $\pi$, denoted $v_{\pi}(s)$, is the expected return when following $\pi$ from $s$. More formally:&lt;/p&gt;
&lt;p&gt;$$
v_{\pi}(s) = E_{\pi}[G_t | S_t=s] = E_{\pi}\Big[\sum_{k=t+1}^{T} \gamma^{k-t-1} R_k \Big| S_t=s \Big]
$$&lt;/p&gt;
&lt;p&gt;The action-value function of a state $s$ and action $a$ when following policy $\pi$, denoted $q_{\pi}(s,a)$, is the the expected return when taking action $a$ in $s$, and then following $\pi$ from the new state. More formally:&lt;/p&gt;
&lt;p&gt;$$
q_{\pi}(s,a) = E_{\pi}[G_t | S_t=s, A_t=t] = E_{\pi}\Big[\sum_{k=t+1}^{T} \gamma^{k-t-1} R_k \Big| S_t=s, A_t=a \Big]
$$&lt;/p&gt;
&lt;p&gt;A policy $\pi$ is better than or equal to policy $\pi^\prime$ if its expected return is greater than or equal to that of $\pi^\prime$ for all states. More formally, $\pi \geq \pi^\prime$ if and only if $v_{\pi}(s) \geq v_{\pi^\prime}(s)$ for all states $s$. There is always at least one optimal policy, and we denote all of them $\pi_*$. Even though there can be multiple optimal policies, they all share the same optimal state-value function, defined as $v_*(s)=\max_{\pi} v_{\pi}(s)=v_{\pi_*}(s)$. They also share the same optimal action-value function, defined as $q_*(s,a)=\max_{\pi} q_{\pi}(s,a)=q_{\pi_*}(s,a)$.&lt;/p&gt;
&lt;h3 id=&#34;bellman-equations&#34;&gt;Bellman Equations&lt;/h3&gt;
&lt;p&gt;An important property of value functions that is used throughout RL and DP is that they are recursive. We can use this property to define Bellman equations. Lets&amp;rsquo; start with the Bellman equation for $v_{\pi}$:&lt;/p&gt;
&lt;p&gt;$$
v_{\pi}(s) = \sum_a \pi(a|s) \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma v_{\pi}(s^\prime) \big)
$$&lt;/p&gt;
&lt;p&gt;Derivation:&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
    v_{\pi}(s) &amp;= E_{\pi}[G_t | S_t=s] \\
        &amp;= E_{\pi}[R_{t+1} + \gamma G_{t+1} | S_t=s] \\
        &amp;= \sum_a \pi(a|s) \sum_{s^\prime} \sum_{r} p(s^\prime,r|s,a) \big(r + \gamma E_{\pi}[G_{t+1} | S_{t+1}=s^\prime] \big) \\
        &amp;= \sum_a \pi(a|s) \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma v_{\pi}(s^\prime) \big)
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;Let&amp;rsquo;s continue with the Bellman equation for $q_{\pi}$:&lt;/p&gt;
&lt;p&gt;$$
q_{\pi}(s,a) = \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma \sum_{a^\prime} \pi(a^\prime|s^\prime) q(a^\prime,s^\prime) \big)
$$&lt;/p&gt;
&lt;p&gt;Derivation:&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
    q_{\pi}(s,a) &amp;= E_{\pi}[G_t | S_t=s, A_t=a] \\
        &amp;= E_{\pi}[R_{t+1} + \gamma G_{t+1} | S_t=s, A_t=a] \\
        &amp;= \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma E_{\pi}[G_{t+1} | S_{t+1}=s^\prime] \big) \\
        &amp;= \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma v_{\pi}(s^\prime)] \big) \\
        &amp;= \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma \sum_{a^\prime} \pi(a^\prime|s^\prime) q(a^\prime,s^\prime) \big)
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;We can do the same thing for the optimal value functions. These are sometimes called Bellman optimality equations. Let&amp;rsquo;s start with the Bellman equation for $v_*$:&lt;/p&gt;
&lt;div&gt;
$$
v_*(s) = \max_a \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma v_*(s^\prime) \big)
$$
&lt;/div&gt;
&lt;p&gt;Derivation:&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
    v_*(s) &amp;= v_{\pi_*}(s) \\
        &amp;= \max_a q_{\pi_*}(s,a) \\
        &amp;= \max_a E_{\pi_*}[G_t | S_t=s, A_t=a] \\
        &amp;= \max_a E_{\pi_*}[R_{t+1} + \gamma G_{t+1} | S_t=s, A_t=a] \\
        &amp;= \max_a E[R_{t+1} + \gamma v_*(S_{t+1}) | S_t=s, A_t=a] \\
        &amp;= \max_a \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma v_*(s^\prime) \big)
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;Let&amp;rsquo;s continue with the Bellman equation for $q_*$:&lt;/p&gt;
&lt;div&gt;
$$
q_*(s,a) = \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma \max_{a^\prime} q_*(s^\prime,a^\prime) \big)
$$
&lt;/div&gt;
&lt;p&gt;Derivation:&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
    q_*(s,a) &amp;= E[R_{t+1} + \gamma \max_{a^\prime} q_*(S_{t+1},a^\prime) | S_t=s, A_t=a] \\
        &amp;= \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma \max_{a^\prime} q_*(s^\prime,a^\prime) \big)
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;The equations must hold for all states, thus we have systems of linear equations we can solve to compute value functions. In the optimal case it is a bit trickier since they involve the non-linear operation &lt;em&gt;max&lt;/em&gt;. The systems of non-linear equations can still be solved to compute the optimal value functions. Even though it is theoretically possible to solve the Bellman equations for all finite MDPs where the dynamics function is known, it becomes very computationally expensive as MDPs get larger.&lt;/p&gt;
&lt;p&gt;If you know an optimal value function, it&amp;rsquo;s trivial to compute the optimal policy. If you have the optimal state-value function, simply look one step ahead in every state and select actions greedily. In other words, select actions by:&lt;/p&gt;
&lt;div&gt;
$$
    \underset{a}{\operatorname{argmax}} \sum_{s^\prime,r} p(s^\prime,r|s,a) \big(r + \gamma v_*(s^\prime) \big)
$$
&lt;/div&gt;
&lt;p&gt;If you have the optimal action-value function it is even easier. In all states, simply select the action with the highest action-value function:&lt;/p&gt;
&lt;div&gt;
$$
    \underset{a}{\operatorname{argmax}} q_*(s,a)
$$
&lt;/div&gt;
&lt;p&gt;If there are ties between actions in either case, you can select any probability distribution between the tying actions.&lt;/p&gt;
&lt;h2 id=&#34;gridworld-example&#34;&gt;Gridworld Example&lt;/h2&gt;
&lt;p&gt;Gridworlds are often used as examples of MDPs. Example 3.5 from (Sutton, Barto, 2018) is a simple example of a finite MDP. The left part of figure 3.2 from (Sutton, Barto, 2018) shows the gridworld. Each cell is a state and the agent can select between the actions: north, south, east, and west in all states. The actions deterministically move the agent in their respective directions, except when they would move the agent out of the grid. In that case the state will remain unchanged and the agent will receive a reward of -1. When the agent selects any action in state $A$, it will be moved to state $A^\prime$ and receive a reward of +10. When the agent selects any action in state $B$, it will be moved to state $B^\prime$ and receive a reward of +5. In all other cases the agent will receive a reward of 0.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/markov_decision_processes/rl_an_introduction_fig_3_2.png&#34; alt=&#34;rl_an_introduction_fig_3_2&#34;&gt;&lt;/p&gt;
&lt;p&gt;The right part of figure 3.2 shows the state-value function $v_{\pi}$ for the policy where all actions are selected at random with equal probability in all states. Since there is no terminal state, discounting was used with rate $\gamma=0.9$. The authors of the book solved the Bellman equations to compute the value function.&lt;/p&gt;
&lt;p&gt;The authors also solved the Bellman equations for $v_*$, and the optimal state-value function can be seen in the middle of figure 3.5 from (Sutton, Barto, 2018). The right part shows the corresponding optimal policies $\pi_*$. When there are multiple arrow-directions in a state, it means that any of those directions is an optimal action in that state.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/markov_decision_processes/rl_an_introduction_fig_3_5.png&#34; alt=&#34;rl_an_introduction_fig_3_5&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;This has been a short introduction to MDPs. The formalization is pretty simple but can model a great deal of problems and is fundamental to the majority of reinforcement learning.&lt;/p&gt;
&lt;p&gt;Thank you for reading and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Multi-armed Bandits</title>
      <link>https://cfml.se/multi-armed-bandits/</link>
      <pubDate>Thu, 03 Oct 2019 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/multi-armed-bandits/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/multi_armed_bandit/multi_armed_bandit.png&#34; alt=&#34;multi_armed_bandit&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;A multi-armed bandit problem is a problem where an agent repeatedly has to make a choice between actions that give the agent rewards. Each action has a different probability distribution that determines how much reward is given, and these distributions are not known to the agent. The objective of the agent is to maximize expected cumulative future reward, which is called expected return. The name multi-armed bandit comes from slot machines (one-armed bandits) and one can imagine the actions the agent takes as choosing which arm to pull on a slot machine with multiple arms.&lt;/p&gt;
&lt;p&gt;Let&amp;rsquo;s say we have a $k$-armed bandit problem, with $k$ actions $a$. We can denote the action selected at time $t$ as $A_t$, and the reward received as $R_t$. The value of an action, $q_*(a)$, is defined as the expected reward given that $a$ is selected:&lt;/p&gt;
&lt;p&gt;$$
q_*(a) = E[R_t|A_t=a]
$$&lt;/p&gt;
&lt;p&gt;If the agent knew the actual value for each action, it would be trivial to solve the problem by simply selecting the action with the highest value all the time. Instead the agent has to estimate the value of each action. We denote the estimate of the value of selecting action $a$ at time $t$ as $Q_t(a)$. If the agent&amp;rsquo;s value estimates are perfect, it can make perfect decisions, so we want $Q_t(a)$ to be as close to $q_*(a)$ as possible.&lt;/p&gt;
&lt;p&gt;A bandit problem is a simple version of Reinforcement Learning (RL) problems. In full RL problems, the agent similarly want to maximize expected return, but the selected actions also affect the state of the agent&amp;rsquo;s environment which affects future rewards and makes the problem more complicated. Even though bandit problems are simpler, studying them provides a good base for understanding many aspects of RL.&lt;/p&gt;
&lt;p&gt;There is a type of bandit problem called contextual bandits where state has to be taken into account. For example you can imagine a problem with multiple $k$-armed bandits and each time step the agent is presented with a random one among them. The agent also receives some clue about what the current bandit is and can learn a policy. A policy is a mapping from states to probability distributions over actions that determines what action to take in any given state. Contextual bandits are closer to full RL problems since they have states, but are still simpler since actions does not affect future states.&lt;/p&gt;
&lt;p&gt;The aim of this article is to introduce the basics of multiple-armed bandit problems. To read more about this subject, I recommend Reinforcement Learning: An Introduction (Sutton, Barto, 2018). The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/multi_armed_bandits&#34;&gt;https://github.com/CarlFredriksson/multi_armed_bandits&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;value-estimation&#34;&gt;Value Estimation&lt;/h2&gt;
&lt;p&gt;To estimate the value of each action we can set our estimates to an initial value $Q_0$ and then update the estimate for an action every time the agent takes that action. The general update rule is as follows:&lt;/p&gt;
&lt;p&gt;$$
Q_{t+1}(a) = Q_t(a) + \alpha (R_t - Q_t(a))
$$&lt;/p&gt;
&lt;p&gt;This update rule brings the estimate closer to the most recently observed reward $R_t$. How much the estimate should move towards the reward depends on $\alpha$, which is called the step size. The step size can be a constant or be computed in various ways. One of the more natural ways to compute $\alpha$ is to set it to $1/N(a)$, where $N(a)$ is the number of times action $a$ has been selected. This step size means that the value estimates will be the average reward received when taking that action. To see that this is the case we can let $Q_{n+1}(a)$ be the average of all received rewards when selecting action $a$ and manipulate the expression. For brevity, assume that we are looking at a single action and drop the $a$. Let $Q_{n}=Q_{n}(a)$, $n=N(a)$, and $R_n$ be the reward received when taking the action the $n$-th time:&lt;/p&gt;
&lt;div&gt;
$$
\begin{aligned}
    Q_{n+1} &amp;= \frac{1}{n} \sum_{i=1}^{n} R_i \\
        &amp;= \frac{1}{n} \Big(R_{n} + \sum_{i=1}^{n-1} R_i \Big) \\
        &amp;= \frac{1}{n} \Big(R_n + (n-1) \frac{1}{n-1} \sum_{i=1}^{n-1} R_i \Big) \\
        &amp;= \frac{1}{n} \Big(R_n + (n-1) Q_n \Big) \\
        &amp;= \frac{1}{n} \Big(R_n + n Q_n - Q_n \Big) \\
        &amp;= Q_n + \frac{1}{n} \Big(R_n - Q_n \Big)
\end{aligned}
$$
&lt;/div&gt;
&lt;p&gt;Thus with $\alpha = 1/N(a)$, the estimate $Q_{t}(a)$ will be the average of all received rewards and will converge to the real value $q_*(a)$ by the law of large numbers. This step size is very good when the problem is stationary, in other words, when $q_*(a)$ stays the same for all $t$. However, when the problem is non-stationary and $q_*(a)$ changes, $\alpha = 1/N(a)$ might not lead to good estimates in the long term. This is because the observed rewards affect the estimates less and less the more times actions gets selected. In other words, $Q_{t}(a)$ will change less as time goes on. When the problem is non-stationary it is more appropriate to use a constant $\alpha \in (0,1]$. This leads to recent rewards affecting the estimates more than older ones, with the extreme $\alpha = 1$ leading to estimates only caring about the most recent reward. Thus a constant $\alpha$ will make the agent more adaptable to changing $q_*(a)$.&lt;/p&gt;
&lt;h2 id=&#34;action-selection&#34;&gt;Action Selection&lt;/h2&gt;
&lt;p&gt;Now we have discussed a couple of options for estimating the value of actions, but how should the agent select actions? The first thing that comes to mind is probably to be greedy and simply select the action with the highest estimated value. In other words, the agent exploits its current knowledge in order to get as much immediate reward as possible. This is an important part of action selection since the objective is to maximize expected return. However, if an agent is too greedy and only exploits, it might not learn accurate value estimates. If an agent has inaccurate value estimates, it can select a sub-optimal action when it exploits. This is why the opposite of exploitation, exploration, is also important when selecting actions.&lt;/p&gt;
&lt;p&gt;Exploration is the act of sacrificing immediate reward for gaining knowledge about the problem, in this case to get better value estimates. There is an important tradeoff in RL (and many other fields) called exploration versus exploitation. This is a tradeoff because an agent can&amp;rsquo;t do both at the same time. The tradeoff needs to be balanced, since if the agent exploits too much it might take a long time to learn what the best action is, or it might never learn it. On the other hand, if the agent explores too much it will miss out on a lot of reward from selecting sub-optimal actions.&lt;/p&gt;
&lt;p&gt;There are many methods for balancing the tradeoff. One of the simpler ones is $\epsilon$-greedy, which means that the agent is greedy with probability $1-\epsilon$, and selects a random action with probability $\epsilon$.&lt;/p&gt;
&lt;h2 id=&#34;action-value-methods&#34;&gt;Action-value Methods&lt;/h2&gt;
&lt;p&gt;When a method for value estimation is combined with a method for action selection it is called an action-value method. To roughly assess the performance of different action-value methods we can use a set of randomly generated $k$-armed bandit problems.&lt;/p&gt;
&lt;p&gt;The code below generates multiple bandit problems with different values $q_*(a)$. It runs $\epsilon$-greedy and greedy only methods ($\epsilon$-greedy with $\epsilon=0$) with different parameter settings on the generated problems. The results are averaged over the generated problems and plots of both reward over time and percentage of optimal action selected over time are saved. The same procedure is also done for the non-stationary case, when $q_*(a)$ changes over time.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;plt&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;my_argmax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    Returns the index of the maximum value in Q_values.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    In case of ties, one of the tying indices are selected at uniform-random and returned.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;max_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;math&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;inf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ties&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Q_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q_values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;max_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;max_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q_val&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;ties&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;elif&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;max_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;ties&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;choice&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ties&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;run_epsilon_greedy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;const_alpha&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q_0&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;stationary&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    Run a single run of the epsilon greedy algorithm.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    Returns the returns and an array where 1 means optimal action was selected and 0 it wasn&amp;#39;t.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;R&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;optimal&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ones&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q_0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;N&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Run steps&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;t&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;not&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;stationary&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;q&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;q&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;normal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;scale&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.05&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Choose action greedily by default and randomly by probability epsilon&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;a&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;my_argmax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;uniform&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;low&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;high&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;1.0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;a&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;randint&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;low&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;high&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;N&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;argmax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;optimal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;t&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Get reward&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;R&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;t&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;normal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;loc&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Update Q&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;const_alpha&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;const_alpha&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;is&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;N&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;R&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;t&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;R&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimal&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;run_experiments&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_runs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;stationary&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    Run multiple runs of epsilon greedy.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    Returns the averaged results and optimal action selections.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;bandits&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;normal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;loc&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;scale&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_runs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;R&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_runs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;optimal&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;R&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_runs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;%&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;100&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Parameter set&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;- Starting run&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;q&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;copy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;bandits&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;epsilon&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;const_alpha&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;const_alpha&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Q_0&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Q_0&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;R&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;run_epsilon_greedy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;q&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;const_alpha&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Q_0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;stationary&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;R&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;optimal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_subplot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;y_label&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;title&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;Plot a subplot of results or optimal action selections.&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_plots&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;arange&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_plots&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;epsilon&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;const_alpha&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;const_alpha&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Q_0&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Q_0&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;const_alpha&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;const_alpha&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;is&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;not&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;1/n&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;label&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;s2&#34;&gt;&amp;#34;epsilon=&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;s2&#34;&gt;&amp;#34;, alpha=&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;s2&#34;&gt;&amp;#34;, Q_0=&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Q_0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;label&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;label&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;legend&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xlabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Step&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ylabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y_label&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;title&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;is&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;not&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;title&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;title&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_results&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;file_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;R_avg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimal_avg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;title&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;Plot the averaged results and optimal action selections.&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;figure&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;figsize&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;8&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plot_subplot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;R_avg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Average reward&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;title&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plot_subplot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;optimal_avg&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;100&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Optimal action %&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;yticks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;20&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;40&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;60&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;80&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;100&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;labels&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;0%&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;20%&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;40%&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;60%&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;80%&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;100%&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tight_layout&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;file_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;vm&#34;&gt;__name__&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;p&#34;&gt;{&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;epsilon&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;const_alpha&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Q_0&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;},&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;p&#34;&gt;{&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;epsilon&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;const_alpha&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Q_0&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;},&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;p&#34;&gt;{&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;epsilon&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.01&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;const_alpha&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Q_0&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;},&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;p&#34;&gt;{&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;epsilon&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;const_alpha&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Q_0&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;},&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;p&#34;&gt;{&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;epsilon&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;const_alpha&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Q_0&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;},&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_runs&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2000&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10000&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;R_avg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimal_avg&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;run_experiments&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_runs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plot_results&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;results_stationary.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;R_avg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimal_avg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Stationary&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;R_avg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimal_avg&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;run_experiments&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_runs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_steps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;stationary&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plot_results&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;results_non-stationary.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;R_avg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimal_avg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Non-stationary&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Plots for the stationary problems:&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/multi_armed_bandit/results_stationary.png&#34; alt=&#34;results_stationary&#34;&gt;&lt;/p&gt;
&lt;p&gt;The greedy method with estimates initialized to 0 ($Q_0=0$) performed the worst. It can be hard to see in the reward plot, but it performed slightly worse than the $\epsilon$-greedy methods with $\epsilon=0.1$. Of the $\epsilon$-greedy methods with $\epsilon=0.1$, the one with $\alpha=1/N(a)$ performed slightly better than the one with a constant $\alpha=0.1$. The $\epsilon$-greedy method with $\epsilon=0.01$ takes a long time before it gets good since it explores only 1% of the time. Eventually it learns good enough estimates and overtakes the other $\epsilon$-greedy methods since it selects sub-optimal actions less.&lt;/p&gt;
&lt;p&gt;The optimal action percentage plot is easier to read and it clearly shows why the greedy method with $Q_0=0$ performed the worst. It does not explore at all and misses out on selecting the optimal action a lot. However, the greedy method with $Q_0=5$ actually performed the best. The reason is that the high initial estimates forces the agent to try out all actions in the beginning, and thus learns good value estimates. After value estimates are learned, the purely greedy method has the advantage of never selecting actions with lower value estimates. This trick is called optimistic initial values and a similar effect can be achieved by an $\epsilon$-greedy method with large initial $\epsilon$ that reduces to 0 after some time.&lt;/p&gt;
&lt;p&gt;Plots for the non-stationary problems:&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/multi_armed_bandit/results_non-stationary.png&#34; alt=&#34;results_non-stationary&#34;&gt;&lt;/p&gt;
&lt;p&gt;This time the greedy method with $Q_0=5$ did not perform as good. It starts out by learning good value estimates, but can&amp;rsquo;t adapt to the changing values. This can be seen by the decreasing optimal action percentage. The $\epsilon$-greedy method with $\alpha=1/N(a)$ performed much worse on the non-stationary version as predicted. Since the updates to the value estimates gets smaller and smaller, it can&amp;rsquo;t adapt to the changing values.&lt;/p&gt;
&lt;p&gt;There are many other action-value methods, such as methods using Upper-Confidence-Bound action selection (UCB). When estimating values there are varying degrees of uncertainty, depending on how often an action has been taken. UCB quantifies this uncertainty as confidence bounds around value estimates and uses these when selecting actions. There are also methods that doesn&amp;rsquo;t use value estimates, such as gradient bandit algorithms.&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;This has been a short introduction to multi-armed bandits. The model is simple but can be used for modelling repeated decision problems in the real world. An example of such a problem could be clinical trials where the effect of different drugs on patients with the same illness are compared. For each patient it has to be decided what drug to treat them with. The reward could be 1 if the patient recovers and 0 otherwise. Multi-armed bandits are also useful for introducing important concepts in RL, such as value estimation and the exploration vs exploitation tradeoff.&lt;/p&gt;
&lt;p&gt;Thank you for reading and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Conditional Generative Adversarial Networks</title>
      <link>https://cfml.se/conditional-generative-adversarial-networks/</link>
      <pubDate>Tue, 22 Jan 2019 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/conditional-generative-adversarial-networks/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/cgans/cdcgan_gen_data_19.png&#34; alt=&#34;Top image&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;A Generative Adversarial Network (GAN) takes noise as input and generates data that resembles examples from the training set. A Conditional Generative Adversarial Network (CGAN) is a type of GAN that condition on additional information. The additional information makes it possible to have more control over the data generation. CGANs where introduced in &lt;a href=https://arxiv.org/abs/1411.1784&gt;Conditional Generative Adversarial Nets&lt;/a&gt; (Mirza, Osindero, 2014).&lt;/p&gt;
&lt;p&gt;A GAN can be conditioned on many types of information. For example, if you have a training set of animal images and their respective class labels, you can condition the model on animal class labels to be able to generate specific animals. Another example is to condition on images in order to do image-to-image translation, such as in &lt;a href=https://arxiv.org/abs/1611.07004&gt;Image-to-Image Translation with Conditional Adversarial Networks&lt;/a&gt; (Isola, Yan Zhu, Zhou, Efros, 2017).&lt;/p&gt;
&lt;p&gt;In this post we are going to use TensorFlow to implement a CGAN for generating specific digits that look handwritten. The dataset we are going to use is the &lt;a href=&#34;http://yann.lecun.com/exdb/mnist/&#34;&gt;MNIST database of handwritten digits&lt;/a&gt;. In my last post I wrote about &lt;a href=&#34;https://cfml.se/deep-convolutional-generative-adversarial-networks/&#34;&gt;DCGANs&lt;/a&gt; and generated random handwritten digits, but in this post we are going to condition the GAN on digit labels in order to generate specific digits.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/cgan_tensorflow&#34;&gt;https://github.com/CarlFredriksson/cgan_tensorflow&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;theory&#34;&gt;Theory&lt;/h2&gt;
&lt;p&gt;For an introduction to the theory behind GANs, see my previous post &lt;a href=&#34;https://cfml.se/generative-adversarial-networks/&#34;&gt;Generative Adversarial Networks&lt;/a&gt;. The only thing that have changed is that the generator $G$ and the discriminator $D$ now condition on additional information $y$.&lt;/p&gt;
&lt;p&gt;One training step consists of:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Sampling a mini batch of $m$ noise vectors ${z^{(1)},\dots,z^{(m)}}$&lt;/li&gt;
&lt;li&gt;Sampling a mini batch of $m$ training examples ${(x^{(1)},y^{(1)}),\dots,(x^{(m)},y^{(m)})}$&lt;/li&gt;
&lt;li&gt;Updating $D$ by doing one gradient descent step on its loss function:&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$$
J_D = - \frac{1}{m} \sum_{i=1}^{m} \Big[log \thinspace D\big(x^{(i)}, y^{(i)}\big) + log\big(1 - D\big(G\big(z^{(i)},y^{(i)}\big),y^{(i)}\big)\big)\Big]
$$&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Updating $G$ by doing one gradient descent step on its loss function:&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$$
J_G = - \frac{1}{m} \sum_{i=1}^{m} \Big[log \thinspace D\big(G\big(z^{(i)},y^{(i)}\big),y^{(i)}\big)\Big]
$$&lt;/p&gt;
&lt;h2 id=&#34;implementation&#34;&gt;Implementation&lt;/h2&gt;
&lt;h3 id=&#34;import-packages&#34;&gt;Import Packages&lt;/h3&gt;
&lt;p&gt;These are the packages we need to import.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;plt&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.gridspec&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;gridspec&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow.keras.datasets&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mnist&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;sklearn.utils&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shuffle&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;load-mnist-data&#34;&gt;Load MNIST Data&lt;/h3&gt;
&lt;p&gt;You could download the dataset from &lt;a href=&#34;http://yann.lecun.com/exdb/mnist/&#34;&gt;http://yann.lecun.com/exdb/mnist/&lt;/a&gt;, but a convenient alternative is to import it from tensorflow.keras.datasets. Note that in the GAN post we didn&amp;rsquo;t need to store the labels, but this time we are going to condition on them.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_mnist_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;load_mnist_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mnist&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;In order to visualize the data we can plot some examples.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;NUM_DIGITS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;SAMPLE_SIZE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_DIGITS&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;**&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot_sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SAMPLE_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;output/mnist_data.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;n&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sqrt&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sample&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;fig&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;figure&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;figsize&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;8&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;8&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;*&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;ax&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;ax&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imshow&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_cmap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;gray&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;ax&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;off&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;ax&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_xticklabels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;ax&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_yticklabels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;fig&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplots_adjust&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;hspace&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.025&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;wspace&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.025&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;close&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/dcgans/mnist_data.png&#34; alt=&#34;MNIST data&#34;&gt;&lt;/p&gt;
&lt;h3 id=&#34;preprocessing&#34;&gt;Preprocessing&lt;/h3&gt;
&lt;p&gt;The generator is going to have a tanh activation at the output layer. Thus we want to normalize the pixel values of the training images from the range $(0,255)$ to $(-1,1)$. We also add a channel dimension of depth 1 since the images were loaded in grayscale.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;preprocess_images&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;preprocess_images&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;255&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The labels also need some preprocessing. We will convert them to one-hot vectors and add some dimensions in order to prepare them for convolutional layers. Note that dimensions are not added in the baseline model.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;preprocess_labels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;preprocess_labels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;convert_to_one_hot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;postprocessing&#34;&gt;Postprocessing&lt;/h3&gt;
&lt;p&gt;Later we are going to plot samples of generated images and we will need a function that is the inverse of the preprocess function.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;postprocess_images&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;squeeze&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;255&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;partition-training-data-into-mini-batches&#34;&gt;Partition Training Data into Mini Batches&lt;/h3&gt;
&lt;p&gt;The training data is randomly shuffled and partitioned into mini batches.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;BATCH_SIZE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random_mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;random_mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shuffle&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Partition into mini-batches&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_complete_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;math&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;floor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_complete_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;startIndex&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;endIndex&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;startIndex&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;endIndex&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;startIndex&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;endIndex&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Handling the case that the last mini-batch &amp;lt; batch_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;%&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;!=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;startIndex&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_complete_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;endIndex&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;startIndex&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;endIndex&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;startIndex&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;endIndex&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;baseline-model&#34;&gt;Baseline Model&lt;/h3&gt;
&lt;p&gt;Let’s start by creating a simple CGAN model in order to establish baseline performance. The generator is a simple feedforward neural network (FNN) that takes 100 dimensional noise vectors $Z$ and 10 dimensional one-hot vectors $Y$ as input. It outputs images with dimensions $(28,28,1)$ and pixel values in the range $(-1,1)$.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variable_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Generator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;concat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;784&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tanh&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reshape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;28&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;28&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The discriminator is also a simple FNN that takes images $X$ and one-hot vectors $Y$ as input. It outputs single values representing guesses of whether an input image is a real or fake image of the given class $y$.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;discriminator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variable_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Discriminator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;flatten&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;concat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;sigmoid&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;If you want the rest of the code for the baseline model check out the source code. After training the CGAN for 100 epochs I got the following result.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/cgans/cgan_gen_data_99.png&#34; alt=&#34;GAN&#34;&gt;&lt;/p&gt;
&lt;h3 id=&#34;cdcgan-model&#34;&gt;CDCGAN Model&lt;/h3&gt;
&lt;p&gt;Let&amp;rsquo;s now create a deep convolutional CGAN (CDCGAN) model which is better suited for our problem. $Y_{fill}$ will contain the same labels as in $Y$, copied to fill $(28,28,10)$. This is done in order to fit the additional information into the purely convolutional architecture of the discriminator.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;NOISE_DIM&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;100&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_DIGITS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_DIGITS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NOISE_DIM&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;bool&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;G&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;create_cdcgan&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_cdcgan&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;G&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;generator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;discriminator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;discriminator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;G&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The generator has the same architecture as in the &lt;a href=&#34;https://cfml.se/deep-convolutional-generative-adversarial-networks/&#34;&gt;DCGANs&lt;/a&gt; post, except for the additional input $Y$.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variable_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Generator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;concat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 1, 1, 110)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d_transpose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;256&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;7&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;valid&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;use_bias&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_normalization&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 7, 7, 256)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d_transpose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;use_bias&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_normalization&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 14, 14, 128)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d_transpose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tanh&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 28, 28, 1)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The discriminator also has the same architecture as in the &lt;a href=&#34;https://cfml.se/deep-convolutional-generative-adversarial-networks/&#34;&gt;DCGANs&lt;/a&gt; post, except for the additional input $Y_{fill}$.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;discriminator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variable_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Discriminator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;concat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 28, 28, 11)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;use_bias&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_normalization&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 14, 14, 128)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;256&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;use_bias&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_normalization&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 7, 7, 256)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;7&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;valid&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sigmoid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 1, 1, 1)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h4 id=&#34;create-training-steps&#34;&gt;Create Training Steps&lt;/h4&gt;
&lt;p&gt;We will use the same optimizers as in the &lt;a href=&#34;https://cfml.se/deep-convolutional-generative-adversarial-networks/&#34;&gt;DCGANs&lt;/a&gt; post.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.0002&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;BETA1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;G_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_loss_func&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;create_loss_funcs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;G_vars&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_collection&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GraphKeys&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GLOBAL_VARIABLES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;scope&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Generator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;D_vars&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_collection&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GraphKeys&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GLOBAL_VARIABLES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;scope&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Discriminator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;G_train_step&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;AdamOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;beta1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;BETA1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;var_list&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G_vars&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;D_train_step&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;AdamOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;beta1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;BETA1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;var_list&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_vars&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_loss_funcs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;eps&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;1e-12&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;G_loss_func&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;eps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;D_loss_func&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;eps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;eps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;G_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_loss_func&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h4 id=&#34;train-model&#34;&gt;Train Model&lt;/h4&gt;
&lt;p&gt;I trained the model for 20 epochs and plotted a sample of generated images after each epoch.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;20&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Start session&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Session&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;global_variables_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Training loop&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;generate_Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NOISE_DIM&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Y_fill_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ones&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;c1&#34;&gt;# Compute losses&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;G_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Epoch [&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{0}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;/&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{1}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;] - G_loss: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{2}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;, D_loss: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{3}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;format&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;G_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;c1&#34;&gt;# Run training steps&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G_train_step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_train_step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Plot generated images&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ones&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_DIGITS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_DIGITS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;arange&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_DIGITS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reshape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SAMPLE_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;astype&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;preprocess_labels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Y_fill_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ones&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SAMPLE_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;generate_Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SAMPLE_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NOISE_DIM&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;gen_imgs&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_fill_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;gen_imgs&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;postprocess_images&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gen_imgs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot_sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gen_imgs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;output/cdcgan/cdcgan_gen_data_&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generate_Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;uniform&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;low&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;high&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Images generated after the 20th epoch:&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/cgans/cdcgan_gen_data_19.png&#34; alt=&#34;Generated epoch 19&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;It was interesting to see that the results were considerably better for the conditioned version of the DCGAN compared to the unconditioned one in the &lt;a href=&#34;https://cfml.se/deep-convolutional-generative-adversarial-networks/&#34;&gt;DCGANs&lt;/a&gt; post. This is probably because the additional information forces the generator and discriminator to be more specific. The unconditioned DCGAN generated some images that looked like combinations of two digits, which barely happened with the conditioned version.&lt;/p&gt;
&lt;p&gt;Thank you for reading and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Deep Convolutional Generative Adversarial Networks</title>
      <link>https://cfml.se/deep-convolutional-generative-adversarial-networks/</link>
      <pubDate>Sun, 16 Dec 2018 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/deep-convolutional-generative-adversarial-networks/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/dcgans/real_vs_generated.png&#34; alt=&#34;Real vs Generated&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;I recently wrote a post about &lt;a href=&#34;https://cfml.se/generative-adversarial-networks/&#34;&gt;Generative Adversarial Networks&lt;/a&gt; (GANs). A Deep Convolutional Generative Adversarial Network (DCGAN) is a type of GAN where the generator is a deep neural network utilizing transpose convolutions and the discriminator is a deep convolutional neural network (CNN). DCGANs were introduced in the paper &lt;a href=https://arxiv.org/abs/1511.06434&gt;Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks&lt;/a&gt; (Radford, Metz, &amp;amp; Chintala, 2016).&lt;/p&gt;
&lt;p&gt;In this post I&amp;rsquo;m going to show how you can use TensorFlow to implement a DCGAN for generating digits that look handwritten. The dataset I used is the &lt;a href=&#34;http://yann.lecun.com/exdb/mnist/&#34;&gt;MNIST database of handwritten digits&lt;/a&gt;, which I also used in my post about &lt;a href=&#34;https://cfml.se/digit-recognition/&#34;&gt;Digit Recognition&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;For an introduction to the theory behind DCGANs I refer you to my previous post about &lt;a href=&#34;https://cfml.se/generative-adversarial-networks/&#34;&gt;Generative Adversarial Networks&lt;/a&gt;. The competition between the generator and discriminator remains the same. The generator takes noise as input and generates fake images, and the discriminator tries to distinguish fake images from real images in the training set. We don&amp;rsquo;t have to change the loss functions, which is a great feature of GANs. The only difference with DCGANs is how the generator and discriminator are implemented.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/dcgan_tensorflow&#34;&gt;https://github.com/CarlFredriksson/dcgan_tensorflow&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;implementation&#34;&gt;Implementation&lt;/h2&gt;
&lt;h3 id=&#34;import-packages&#34;&gt;Import Packages&lt;/h3&gt;
&lt;p&gt;These are the packages we need to import.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;plt&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.gridspec&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;gridspec&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow.keras.datasets&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mnist&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;load-mnist-data&#34;&gt;Load MNIST Data&lt;/h3&gt;
&lt;p&gt;You could download the dataset from &lt;a href=&#34;http://yann.lecun.com/exdb/mnist/&#34;&gt;http://yann.lecun.com/exdb/mnist/&lt;/a&gt;, but a convenient alternative is to import it from tensorflow.keras.datasets. Since we are doing unsupervised learning we don&amp;rsquo;t need to store the labels.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_mnist_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;load_mnist_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mnist&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;In order to visualize the data we can plot some examples.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;SAMPLE_SIZE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;100&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot_sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SAMPLE_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;output/mnist_data.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;n&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sqrt&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sample&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;fig&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;figure&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;figsize&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;8&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;8&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;*&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;ax&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;ax&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imshow&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_cmap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;gray&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;ax&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;off&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;ax&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_xticklabels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;ax&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;set_yticklabels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;fig&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplots_adjust&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;hspace&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.025&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;wspace&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.025&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;close&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/dcgans/mnist_data.png&#34; alt=&#34;MNIST data&#34;&gt;&lt;/p&gt;
&lt;h3 id=&#34;preprocessing&#34;&gt;Preprocessing&lt;/h3&gt;
&lt;p&gt;The generator is going to have a tanh activation at the output layer. Thus we want to normalize the pixel values of the training images from the range $(0,255)$ to $(-1,1)$. We also add a channel dimension of depth 1 since the images were loaded in grayscale.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;preprocess_images&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;preprocess_images&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;255&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;postprocessing&#34;&gt;Postprocessing&lt;/h3&gt;
&lt;p&gt;Later we are going to plot samples of generated images and we will need a function that is the inverse of the preprocess function.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;postprocess_images&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;squeeze&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;255&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;partition-training-data-into-mini-batches&#34;&gt;Partition Training Data into Mini Batches&lt;/h3&gt;
&lt;p&gt;The training data is randomly shuffled and partitioned into mini batches.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;BATCH_SIZE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random_mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;random_mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shuffle&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Partition into mini-batches&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_complete_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;math&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;floor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_complete_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Handling the case that the last mini-batch &amp;lt; batch_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;%&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;!=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_complete_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;baseline-model&#34;&gt;Baseline Model&lt;/h3&gt;
&lt;p&gt;Let&amp;rsquo;s start by creating a simple GAN model in order to establish baseline performance. The generator is a simple feedforward neural network (FNN) that takes 100 dimensional noise vectors $Z$ as input, and outputs images with dimensions $(28,28,1)$ and pixel values in the range $(-1,1)$. The discriminator is also a simple FNN that takes images of dimensions $(28,28,1)$ as input and outputs single values representing guesses of whether an input image is real or fake.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variable_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Generator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;784&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tanh&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reshape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;28&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;28&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;discriminator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variable_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Discriminator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;flatten&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;sigmoid&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;If you want the rest of the code for the baseline model check out the source code. After training the GAN for 100 epochs I got the following result.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/dcgans/gan_gen_data_99.png&#34; alt=&#34;GAN&#34;&gt;&lt;/p&gt;
&lt;h3 id=&#34;dcgan-model&#34;&gt;DCGAN Model&lt;/h3&gt;
&lt;p&gt;Let&amp;rsquo;s now create a DCGAN model which is better suited for our problem.&lt;/p&gt;
&lt;h4 id=&#34;create-generator&#34;&gt;Create Generator&lt;/h4&gt;
&lt;p&gt;The generator takes 100 dimensional noise vectors $Z$ as input and outputs images with dimensions $(28,28,1)$ and pixel values in the range $(-1,1)$. The network architecture is heavily inspired by the architecture guidelines suggested in the original DCGAN paper &lt;a href=https://arxiv.org/abs/1511.06434&gt;(Radford, Metz, &amp;amp; Chintala, 2016)&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;We will use transpose convolutions as a way to do upsampling in a learnable way. Traditional methods of upsampling like nearest-neighbor interpolation or bilinear interpolation are methods without learnable parameters. For generating images we want the network to be able to learn its own upsampling method in order to be able to create convincing fake images.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variable_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Generator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 1, 1, 100)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d_transpose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;256&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;7&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;valid&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;use_bias&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_normalization&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 7, 7, 256)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d_transpose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;use_bias&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_normalization&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 14, 14, 128)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d_transpose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tanh&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 28, 28, 1)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h4 id=&#34;create-discriminator&#34;&gt;Create Discriminator&lt;/h4&gt;
&lt;p&gt;The discriminator is a CNN that takes images of dimensions $(28,28,1)$ as input and outputs single values representing guesses of whether an input image is real or fake. The network looks like a normal CNN for binary image classification with a few modifications. Most notably, we use strided convolutions instead of pooling layers for downsampling. This is so the network can learn its own downsampling, similarly to the upsampling in the generator.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;discriminator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variable_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Discriminator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 28, 28, 1)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;use_bias&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_normalization&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 14, 14, 128)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;256&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;use_bias&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch_normalization&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 7, 7, 256)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;7&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;valid&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sigmoid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# x.shape: (?, 1, 1, 1)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h4 id=&#34;create-dcgan&#34;&gt;Create DCGAN&lt;/h4&gt;
&lt;p&gt;The overall structure is the same as in a regular GAN. The generator $G$ generates fake examples from noise. The discriminator $D$ takes both real examples from training data and fake examples from $G$ as input, and outputs guesses on which examples are real and which are fake.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/dcgans/GAN.png&#34; alt=&#34;GAN&#34;&gt;&lt;/p&gt;
&lt;p&gt;Note that I used the placeholder &lt;code&gt;is_training&lt;/code&gt; mostly out of habit. It&amp;rsquo;s going to be set to true for both training and generation of images. When doing supervised learning there is a clear distinction between training a model and using it for inference. When batch normalization layers are applied during training, the normalization step is done using mean and variance of the current mini-batch. When doing inference, the normalization step uses population mean and population variance instead, which are statistics estimated during training using moving mean and variance.&lt;/p&gt;
&lt;p&gt;Since we are generating new data from noise and not doing any inference, there is no need to change the batch norm layers from training mode to inference mode. I tried both and found that keeping training mode on worked the best. When images where generated with inference mode on (&lt;code&gt;is_training&lt;/code&gt; set to false) they were slightly blurry, and more importantly they ended up generating the same number (9) in most cases.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;NOISE_DIM&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;100&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NOISE_DIM&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;bool&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;G&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;create_gan&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_gan&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;G&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;generator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;discriminator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;discriminator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;G&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h4 id=&#34;create-training-steps&#34;&gt;Create Training Steps&lt;/h4&gt;
&lt;p&gt;Just like in a regular GAN we will use separate training steps for the generator and discriminator. We will also use the same loss functions as in a regular GAN. As suggested in the DCGAN paper, we will use Adam optimizers with parameters $\beta_1$ and learning rate set to 0.5 and 0.0002 respectively.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;BETA1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.0002&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;G_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_loss_func&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;create_loss_funcs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;G_vars&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_collection&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GraphKeys&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GLOBAL_VARIABLES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;scope&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Generator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;D_vars&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_collection&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GraphKeys&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GLOBAL_VARIABLES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;scope&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Discriminator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;G_train_step&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;AdamOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;beta1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;BETA1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;var_list&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G_vars&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;D_train_step&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;AdamOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;beta1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;BETA1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;var_list&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_vars&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_loss_funcs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;eps&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;1e-12&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;G_loss_func&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;eps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;D_loss_func&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_real&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;eps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;eps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;G_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_loss_func&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h4 id=&#34;train-model&#34;&gt;Train Model&lt;/h4&gt;
&lt;p&gt;I trained the model for 20 epochs and plotted a sample of generated images after each epoch.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;20&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Start session&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Session&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;global_variables_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Training loop&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;generate_Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NOISE_DIM&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;c1&#34;&gt;# Compute losses&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;G_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_loss_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Epoch [&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{0}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;/&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{1}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;] - G_loss: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{2}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;, D_loss: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{3}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;format&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;G_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;D_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;c1&#34;&gt;# Run training steps&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G_train_step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;D_train_step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Plot generated images&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;generate_Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SAMPLE_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NOISE_DIM&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;gen_imgs&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;G&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;is_training&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;gen_imgs&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;postprocess_images&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gen_imgs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot_sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gen_imgs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;output/dcgan/dcgan_gen_data_&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generate_Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;uniform&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;low&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;high&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Images generated after the 20th epoch:&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/dcgans/dcgan_gen_data_19.png&#34; alt=&#34;Generated epoch 19&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;This was a fun project and I&amp;rsquo;m happy about how it turned out. Not all of the generated images could fool a human, but most of them look like they could be handwritten digits to me. Although generating handwritten digits is not the most exiting objective, it serves as a good introductory task for learning how to implement and tune DCGANs.&lt;/p&gt;
&lt;p&gt;Thank you for reading and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Generative Adversarial Networks</title>
      <link>https://cfml.se/generative-adversarial-networks/</link>
      <pubDate>Mon, 12 Nov 2018 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/generative-adversarial-networks/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/generative_adversarial_networks/counterfeiter_vs_cop.jpg&#34; alt=&#34;Counterfeiter vs Cop&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;Generative Adversarial Networks (GANs) have been getting a lot of attention recently. They have been used for many different tasks, such as generating anime characters (Jin et al., 2017, &lt;a href=&#34;https://arxiv.org/abs/1708.05509)&#34;&gt;https://arxiv.org/abs/1708.05509)&lt;/a&gt;, generating photos of fake celebrities (Karras, Aila, Laine, &amp;amp; Lehtinen, 2017, &lt;a href=&#34;https://arxiv.org/abs/1710.10196)&#34;&gt;https://arxiv.org/abs/1710.10196)&lt;/a&gt;, increasing resolution in photos (Ledig et al., 2016, &lt;a href=&#34;https://arxiv.org/abs/1609.04802)&#34;&gt;https://arxiv.org/abs/1609.04802)&lt;/a&gt;, and general image-to-image translation (Isola, Yan Zhu, Zhou, Efros, 2017, &lt;a href=&#34;https://arxiv.org/abs/1611.07004)&#34;&gt;https://arxiv.org/abs/1611.07004)&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;GANs were introduced in (Goodfellow et al., 2014, &lt;a href=&#34;https://arxiv.org/abs/1406.2661&#34;&gt;https://arxiv.org/abs/1406.2661&lt;/a&gt;) as an unsupervised learning model, trained on data without labels. The model learns to take noise as input and output data that resembles examples from the training set. There are versions of GANs that can utilize additional information as input, such as conditional GANs (CGANs), but in this post I&amp;rsquo;m going to focus on the basic version.&lt;/p&gt;
&lt;p&gt;I&amp;rsquo;m going to explain some of the theory behind GANs and then show how you can implement a simple version in TensorFlow. The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/gan_tensorflow&#34;&gt;https://github.com/CarlFredriksson/gan_tensorflow&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;theory&#34;&gt;Theory&lt;/h2&gt;
&lt;p&gt;A GAN consists of two parts, a generator $G$ and a discriminator $D$. The objective for $G$ is to take random noise $z$ as input and generate a fake example that looks like it comes from the real training data. The objective for $D$ is to take an example $x$ from the training data or a fake example $G(z)$ as input and output either 0 or 1. If the example is from the training data $D$ should output 1, and if the example is fake $D$ should output 0. In other words, $D$ is a binary classifier on whether the input is real (part of the training data) or fake (generated by $G$).&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/generative_adversarial_networks/GAN.png&#34; alt=&#34;GAN&#34;&gt;&lt;/p&gt;
&lt;p&gt;Note that $G$ is what we are ultimately interested in. After training, $G$ is what we will use in our applications to generate new data. $D$ is there to help train $G$. Training using an adversary in this way is a powerful idea that lets us define simple loss functions for $G$ and $D$ that don&amp;rsquo;t need to be changed between different tasks. GANs remove the need for handcrafting specific loss functions for many problems.&lt;/p&gt;
&lt;h3 id=&#34;minimax-game&#34;&gt;Minimax Game&lt;/h3&gt;
&lt;p&gt;$G$ and $D$ have directly competing objectives, and the situation can be modelled as a minimax zero-sum game where $G$ and $D$ are the players. When $G$ learns to create better fakes, it forces $D$ to get better at detecting them and vice versa. The competition makes $G$ and $D$ better and better until a steady state is reached.&lt;/p&gt;
&lt;p&gt;An analogy could be a cop vs a criminal making counterfeit money. As the cop gets better at detecting fake bills, the criminal has to learn more complex counterfeiting methods, which in turn forces the cop to get even better, and so on.&lt;/p&gt;
&lt;p&gt;The game that $G$ and $D$ plays can be modelled by the value function $V(D,G)$:&lt;/p&gt;
&lt;p&gt;$$
\underset{G}{min} \thinspace \underset{D}{max} \thinspace V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[log \thinspace D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[log(1 - D(G(z)))] \tag 1
$$&lt;/p&gt;
&lt;h3 id=&#34;training&#34;&gt;Training&lt;/h3&gt;
&lt;p&gt;One training step consists of:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Sample a mini batch of $m$ noise vectors ${z^{(1)},\dots,z^{(m)}}$&lt;/li&gt;
&lt;li&gt;Sample a mini batch of $m$ training examples ${x^{(1)},\dots,x^{(m)}}$&lt;/li&gt;
&lt;li&gt;Update $D$ by doing one gradient descent step on its loss function:&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$$
J_D = - \frac{1}{m} \sum_{i=1}^{m} \Big[log \thinspace D\big(x^{(i)}\big) + log\big(1 - D\big(G\big(z^{(i)}\big)\big)\big)\Big] \tag 2
$$&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Update $G$ by doing one gradient descent step on its loss function:&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;$$
J_G = - \frac{1}{m} \sum_{i=1}^{m} \Big[log \thinspace D\big(G\big(z^{(i)}\big)\big)\Big] \tag 3
$$&lt;/p&gt;
&lt;p&gt;I chose to minimize $-log \thinspace D(G(z))$, rather than $log(1 - D(G(z)))$. This is suggested in the original GAN paper to provide stronger gradients for $G$ early in training, while keeping the same fixed point dynamics of $G$ and $D$ as in equation 1. Since the performance of $G$ is poor in the early stages, $D$ can easily distinguish fake examples from real training data, which results in small gradients for $G$. If needed, the imbalance can be combated further by slowing down the learning of $D$. This can be done by running more than one training step for $G$ for every step of $D$ or lowering the learning rate of $D$.&lt;/p&gt;
&lt;h2 id=&#34;implementation&#34;&gt;Implementation&lt;/h2&gt;
&lt;p&gt;Let us now implement a GAN. I chose to keep it simple in order to focus on the basic ideas. We are going to generate a two-dimensional dataset using a quadratic function, and implement $G$ and $D$ as small feedforward neural networks (FNNs).&lt;/p&gt;
&lt;h3 id=&#34;import-packages&#34;&gt;Import Packages&lt;/h3&gt;
&lt;p&gt;These are the packages we need to import.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;math&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;plt&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;generate-data&#34;&gt;Generate Data&lt;/h3&gt;
&lt;p&gt;The data is generated using the function $y=x^2 + 5$. The data contains two-dimensional points without labels that are shuffled and partitioned into mini-batches for training.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;generate_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;plot_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;random_mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;BATCH_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generate_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1000&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;scale&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;scale&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;*&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random_sample&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;**&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;o&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/data.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;random_mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shuffle&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Partition into mini-batches&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_complete_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;math&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;floor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_complete_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Handling the case that the last mini-batch &amp;lt; batch_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;%&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;!=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_complete_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/generative_adversarial_networks/data.png&#34; alt=&#34;Data&#34;&gt;&lt;/p&gt;
&lt;h3 id=&#34;create-model&#34;&gt;Create Model&lt;/h3&gt;
&lt;p&gt;It is time to create the GAN. Let&amp;rsquo;s start by defining two placeholders &lt;code&gt;X&lt;/code&gt; and &lt;code&gt;Z&lt;/code&gt;. &lt;code&gt;X&lt;/code&gt; is the placeholder for batches of training data, and &lt;code&gt;Z&lt;/code&gt; is the placeholder for batches of random noise.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Let&amp;rsquo;s create two functions for defining the generator $G$ and the discriminator $D$.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_sizes&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variable_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;GAN/Generator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;out&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;out&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;discriminator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_sizes&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variable_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;GAN/Discriminator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_sizes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;leaky_relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;out&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sigmoid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;out&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Using our &lt;code&gt;generator&lt;/code&gt; and &lt;code&gt;discriminator&lt;/code&gt; functions, we can create the GAN model. The tensors we are interested in is the output of $G$, the output of $D$ when real training data is used as input, and the output of $D$ when the input is generated fakes from $G$. The &lt;code&gt;reuse&lt;/code&gt; flag is set to true when we call the &lt;code&gt;discriminator&lt;/code&gt; function for the second time, since we only want one set of variables for $D$.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;gen_out&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;disc_out_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;disc_out_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;create_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;gen_out&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;generator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;disc_out_real&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;discriminator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;disc_out_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;discriminator&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gen_out&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;reuse&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gen_out&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;disc_out_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;disc_out_fake&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;train-model&#34;&gt;Train Model&lt;/h3&gt;
&lt;p&gt;To train the model we need to create loss functions, which are defined in equations 2 and 3. Note the addition of the small value &lt;code&gt;EPS&lt;/code&gt; in order to avoid taking the log of 0.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;EPS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;1e-12&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;disc_loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;disc_out_real&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;EPS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;disc_out_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;EPS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;gen_loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;disc_out_fake&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;EPS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Since we are alternating between training $G$ and $D$, we need to create a separate training step for each.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.001&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;gen_vars&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_collection&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GraphKeys&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GLOBAL_VARIABLES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;scope&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;GAN/Generator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;disc_vars&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_collection&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GraphKeys&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GLOBAL_VARIABLES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;scope&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;GAN/Discriminator&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;gen_train_step&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;AdamOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gen_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;var_list&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gen_vars&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;disc_train_step&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;AdamOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;disc_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;var_list&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;disc_vars&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Now we are ready to train the model. The performance of the generator is plotted every 250th epoch, and some of the results for one training run can be seen below.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5000&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;BATCH_SIZE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;32&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Start session&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Session&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;global_variables_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Training loop&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;generate_Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;disc_loss_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;disc_train_step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;disc_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gen_loss_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gen_train_step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gen_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;%&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;250&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Epoch &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34; - plotting generated data&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;plot_generated_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;lambda&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gen_out&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;}),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Training finished - plotting generated data&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plot_generated_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;lambda&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gen_out&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;}),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_generated_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gen_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;generate_Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;gen_data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gen_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Z_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;o&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gen_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gen_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;o&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;title&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Epoch &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/gen_data_&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/generative_adversarial_networks/gen_data_0.png&#34; alt=&#34;Generated data 0&#34;&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/generative_adversarial_networks/gen_data_250.png&#34; alt=&#34;Generated data 250&#34;&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/generative_adversarial_networks/gen_data_1000.png&#34; alt=&#34;Generated data 1000&#34;&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/generative_adversarial_networks/gen_data_4999.png&#34; alt=&#34;Generated data 4999&#34;&gt;&lt;/p&gt;
&lt;p&gt;After 5000 epochs, $G$ has learned to generate data points that resemble the training data pretty well.&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;GANs are based on the idea of adversarial learning, and are capable of incredible performance on a wide variety of problems. In order to focus on the basic properties of GANs, this post showed how to implement a simple GAN and apply it on a simple task. I urge you to check out some of the recent papers to see really cool applications of GANs if you haven&amp;rsquo;t already.&lt;/p&gt;
&lt;p&gt;I hope this post was useful to you, and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Sentiment Classification</title>
      <link>https://cfml.se/sentiment-classification/</link>
      <pubDate>Tue, 23 Oct 2018 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/sentiment-classification/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/sentiment_classification/top_img.png&#34; alt=&#34;Sentiment classification&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;Sentiment classification is a Natural Language Processing (NLP) task where the objective is to predict the sentiment of some text input. For example, predicting a one to five star rating given a product review, or predicting the emotion of a text message. In this post I&amp;rsquo;m going to show you how I implemented sentiment classification for predicting whether a movie review is positive or negative.&lt;/p&gt;
&lt;p&gt;The dataset I used is the &amp;ldquo;Large Movie Review Dataset&amp;rdquo; which can be found at &lt;a href=&#34;http://ai.stanford.edu/~amaas/data/sentiment/&#34;&gt;http://ai.stanford.edu/~amaas/data/sentiment/&lt;/a&gt;. It was introduced in the paper &lt;a href=http://ai.stanford.edu/~amaas/papers/wvSent_acl2011.pdf&gt;Learning Word Vectors for Sentiment Analysis&lt;/a&gt; by Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, and Christopher Potts. (ACL 2011). The dataset contains 25000 highly polar (distinctly positive or negative) movie reviews for training and 25000 for testing. I kept the training data as my training set, but split the test data into a validation set used to evaluate different models, and a test set used to evaluate if the final model generalizes well.&lt;/p&gt;
&lt;p&gt;The project was implemented in Keras, and the source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/sentiment_classification&#34;&gt;https://github.com/CarlFredriksson/sentiment_classification&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;implementation&#34;&gt;Implementation&lt;/h2&gt;
&lt;h3 id=&#34;load-data&#34;&gt;Load Data&lt;/h3&gt;
&lt;p&gt;After downloading the dataset, you will see that both the training and test data is divided into directories of negative and positive examples. Each example has its own file, which means that we need to open every file, read it, and append the text to the correct input set (&lt;code&gt;X_train&lt;/code&gt; or &lt;code&gt;X_test&lt;/code&gt;). We also need to append a label of 0 or 1 to the corresponding output sets &lt;code&gt;Y_train&lt;/code&gt; or &lt;code&gt;Y_test&lt;/code&gt; for each text example. I chose to use 0 for examples with negative sentiment, and 1 for examples with positive sentiment.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;os&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sc_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;load_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# neg = 0, pos = 1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[],&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[],&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[],&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Prepare paths&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;data_path&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;getcwd&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(),&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;data&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;train_neg_path&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;train&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;neg&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;train_pos_path&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;train&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;pos&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;test_neg_path&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;test&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;neg&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;test_pos_path&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;test&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;pos&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Load training data&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;filename&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;listdir&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_neg_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;open&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_neg_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;filename&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;encoding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;utf8&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;review&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;readline&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;review&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;filename&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;listdir&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_pos_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;open&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_pos_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;filename&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;encoding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;utf8&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;review&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;readline&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;review&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Training data loaded&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Load test data&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;filename&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;listdir&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;test_neg_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;open&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;test_neg_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;filename&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;encoding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;utf8&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;review&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;readline&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;review&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;filename&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;listdir&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;test_pos_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;open&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;test_pos_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;filename&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;encoding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;utf8&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;review&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;readline&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;review&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Test data loaded&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;preprocess-data&#34;&gt;Preprocess Data&lt;/h3&gt;
&lt;p&gt;The first thing we want to do is to encode texts as sequences of numbers, where each number corresponds to a word index. To create such word indices we need to know what words are used in the dataset. Fortunately, there is a vocabulary file provided with the dataset which contains all used words.&lt;/p&gt;
&lt;p&gt;To help with the task of encoding text as sequences of word indices we can use the Keras class &lt;code&gt;Tokenizer&lt;/code&gt;. The &lt;code&gt;Tokenizer&lt;/code&gt; assigns an index for each word it considers distinct in the vocabulary. There is a default filter that disregards the symbols &lt;code&gt;!&amp;quot;#$%&amp;amp;()*+,-./:;&amp;lt;=&amp;gt;?@[\]^_`{|}~&lt;/code&gt;, which makes the &lt;code&gt;Tokenizer&lt;/code&gt; not consider some lines in the vocabulary as words. For example, the vocabulary contains many emojis (such as &lt;code&gt;:)&lt;/code&gt; and &lt;code&gt;:-p&lt;/code&gt;) which are disregarded because of this filter. Changing the default filter might be a worthwhile experiment, but I did not bother for this project.&lt;/p&gt;
&lt;p&gt;After initializing the tokenizer on our vocabulary using &lt;code&gt;tokenizer.fit_on_texts(vocab)&lt;/code&gt;, we can convert our texts to sequences of word indices. The sequences are then padded to make them equal length, which is important for training in batches. Sequences that are too short get zeroes added at the end, and sequences that are too long gets clipped.&lt;/p&gt;
&lt;p&gt;After converting the processed training and test data to numpy arrays, the data is shuffled and the test data is divided into equally sized validation and test sets.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.preprocessing.text&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tokenizer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.preprocessing.sequence&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pad_sequences&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;sklearn.utils&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shuffle&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;INPUT_LENGTH&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tokenizer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sc_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;preprocess_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;INPUT_LENGTH&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;preprocess_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;input_len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Load vocabulary&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;vocab_path&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;getcwd&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(),&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;data&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;imdb.vocab&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;open&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;vocab_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;encoding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;utf8&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;vocab&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;read&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;splitlines&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Vocabulary length before tokenizing: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;vocab&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Prepare tokenizer, the out of vocabulary token for GloVe is &amp;#34;unk&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;tokenizer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tokenizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;oov_token&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;unk&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;tokenizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;fit_on_texts&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;vocab&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;vocab_len&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tokenizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;word_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Vocabulary length after tokenizing: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;vocab_len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Convert text to sequences of indices&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tokenizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;texts_to_sequences&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tokenizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;texts_to_sequences&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Pad sequences with zeros so they all have the same length&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pad_sequences&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;maxlen&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;post&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pad_sequences&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;maxlen&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;post&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Convert training and test data to numpy arrays&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Shuffle training and test data&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shuffle&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shuffle&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Split test data into validation and test sets&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;split_index&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;split_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;split_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;split_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;split_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tokenizer&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;create-embedding-matrix&#34;&gt;Create Embedding Matrix&lt;/h3&gt;
&lt;p&gt;A common tool in NLP is to use word embeddings, which are vectors that characterizes words. Each word has a corresponding vector, and words with similar meaning are expected to have embedding vectors that are close to each other in the vector space. Using word embeddings has many advantages, one of which is making models more robust to rarely seen words. Word embeddings can be learned as part of the model, or you can use pretrained embeddings. I tried both approaches and found that using pretrained embeddings was a better choice for this project since it made the model overfit less. There are many types of word embeddings. I decided to use pretrained GloVe embeddings which can be downloaded from &lt;a href=&#34;https://nlp.stanford.edu/projects/glove/&#34;&gt;https://nlp.stanford.edu/projects/glove/&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;In order to use the embeddings in our models, we need to create a matrix where each row contains an embedding vector for the word that has a word index equal to the index of that row.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sc_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;create_embedding_matrix&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tokenizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_embedding_matrix&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tokenizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Load GloVe embedding vectors&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;embedding_path&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;getcwd&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(),&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;data&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;glove.6B&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;glove.6B.100d.txt&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;word_to_embedding&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;{}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;open&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;embedding_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;encoding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;utf8&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;line&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;readlines&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;values&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;line&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;split&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;word&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;embedding_vec&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;asarray&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;word_to_embedding&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;word&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;embedding_vec&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Embedding vectors loaded&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create embedding matrix&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;embedding_vec_dim&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;vocab_len&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tokenizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;word_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;vocab_len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;embedding_vec_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;word&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tokenizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;word_index&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;items&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;embedding_vec&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;word_to_embedding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;word&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;embedding_vec&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;is&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;not&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;embedding_vec&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Embedding matrix created&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;create-models&#34;&gt;Create Models&lt;/h3&gt;
&lt;p&gt;Now it is time to create some models. First let&amp;rsquo;s create a very simple model to get a baseline performance. The baseline model simply averages the word embeddings of an input. The Embedding layer is used to convert sequences of word indices to sequences of embedding vectors. Not that since we are using pretrained embeddings that we don&amp;rsquo;t want to train further, the weights are initialized to our embedding matrix and the &lt;code&gt;trainable&lt;/code&gt; flag is set to false. Also note that the parameter &lt;code&gt;mask_zero&lt;/code&gt; is set to true. This masks the padded zeros in the input sequences, which is good since we only need them to make sure that the input lengths are equal. The output is a dense layer with a single output and sigmoid activation since we are doing binary classification.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.models&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Input&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.layers&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Flatten&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Embedding&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Average&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Activation&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Lambda&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Dropout&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;LSTM&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Bidirectional&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.initializers&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Constant&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.backend&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;K&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;regularizers&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_baseline_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;input_len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Embedding&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;embeddings_initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Constant&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;input_length&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;trainable&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mask_zero&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Lambda&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;k&#34;&gt;lambda&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;K&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;sigmoid&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;compile&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;adam&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;binary_crossentropy&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;metrics&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;accuracy&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Let&amp;rsquo;s now create a more complex model. This model uses a recurrent neural network (RNN) with LSTM layers, which is very common for many types of sequential data. I tried several hyperparameter settings such as varying the number of LSTM, Dropout, and Dense layers. I also varied the output sizes of the LSTM and Dense layers, and the rate of the Dropout layers. The model I ended up with, by choosing the one with the highest accuracy on the validation set, has two layers of LSTM units and one extra Dense layer. Note that the parameter &lt;code&gt;return_sequences&lt;/code&gt; is set to true for the first LSTM layer. This makes the layer return all outputs, instead of returning the last one only, which is required when connecting two recurrent layers.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_rnn_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;input_len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Embedding&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;embeddings_initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Constant&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;input_length&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;trainable&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mask_zero&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LSTM&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;return_sequences&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;recurrent_dropout&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dropout&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LSTM&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;sigmoid&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;compile&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;adam&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;binary_crossentropy&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;metrics&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;accuracy&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;I also tried some versions of a bidirectional RNN and some versions of a RNN where the embeddings were learned instead of pretrained. None of them performed better on the validation set. Implementations of two of these models are available with the source code but omitted here.&lt;/p&gt;
&lt;h3 id=&#34;train-and-evaluate-models&#34;&gt;Train and Evaluate Models&lt;/h3&gt;
&lt;p&gt;I used a batch size of 200 for all training runs. I trained the baseline model for 500 epochs since it took a lot of epochs before the training loss converged and each epoch ran very quickly. The baseline model scored 77.16% accuracy on the validation set.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_factory&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;create_baseline_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;INPUT_LENGTH&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;fit&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;200&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epochs&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;500&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;val_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;val_accuracy&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;evaluate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;verbose&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;I only trained the RNN model for 30 epochs since the training loss converged much quicker. The RNN model scored 85.45% accuracy on the validation set. Since this is the final model I also evaluated it on the test set. The accuracy on the test set was 85.43%, which suggests that the model generalizes well (at least on highly polar movie reviews).&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_factory&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;create_rnn_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;embedding_matrix&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;INPUT_LENGTH&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;fit&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;200&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epochs&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;30&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;val_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;val_accuracy&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;evaluate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;verbose&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;test_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;test_accuracy&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;evaluate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;verbose&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;Sentiment classification can be used for many applications, for example when analyzing customer reviews or survey responses. In marketing and business development one might be interested in extracting customer information from textual data, such as the percentage of people that are happy with your product. It can be used to assess the emotion of a user that inputs text or speech into an application, which could be of interest in all kinds of applications ranging from chat bots to medical self-help systems.&lt;/p&gt;
&lt;p&gt;In order to further improve the performance of the model there are more things to try that I didn&amp;rsquo;t bother with. One of which is the input length of sequences. I set it to 100, which means that only the first 100 words of each review is considered. Many reviews in the dataset are much longer, and using more words might improve the classification accuracy. However, this would make the training times slower. Another thing to try is to change the dimensionality of the word embeddings. The pretrained GloVe embeddings comes with vectors of dimensions 50, 100, 200, and 300, but I only tried the 100 dimensional ones.&lt;/p&gt;
&lt;p&gt;I hope this post helped you out, and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Facial Landmark Detection</title>
      <link>https://cfml.se/facial-landmark-detection/</link>
      <pubDate>Fri, 05 Oct 2018 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/facial-landmark-detection/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/facial_landmark_detection/test_img_prediction.png&#34; alt=&#34;Facial landmarks&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;Landmark detection is a computer vision problem where an algorithm tries to find the locations of landmarks (also called keypoints) in an image. In facial landmark detection the landmarks correspond to facial features such as the center of the eyes or the edges of the mouth. After the algorithm predicts the location of the landmarks, they can be used for applications such as applying filters to the image (like Snapchat filters), detecting the emotional state of the person in the image, or classifying whether the person is male or female.&lt;/p&gt;
&lt;p&gt;In this post I will show you how I implemented facial landmark detection in Keras. The dataset I used can be found at &lt;a href=&#34;https://www.kaggle.com/c/facial-keypoints-detection/data&#34;&gt;https://www.kaggle.com/c/facial-keypoints-detection/data&lt;/a&gt;. It contains grayscale images of size (96, 96) and up to 15 landmarks for each image.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/facial_landmark_detection&#34;&gt;https://github.com/CarlFredriksson/facial_landmark_detection&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;implementation&#34;&gt;Implementation&lt;/h2&gt;
&lt;h3 id=&#34;load-data&#34;&gt;Load Data&lt;/h3&gt;
&lt;p&gt;The dataset contains csv files for training and testing. The test data contains only images and no landmarks, thus is of no use except for manual visual testing. I discarded the test data and instead tested the final model with images I found on my own. The training data contains 7049 rows, where each row contains landmark coordinates separated by commas and image pixels separated by spaces. Unfortunately many of the rows have missing landmark columns and after discarding those rows, we are left with 2140 labeled examples. After shuffling, I split the labeled examples into a training set and a validation set. The fraction of examples that is put in the validation set is determined by the parameter &lt;code&gt;validation_split&lt;/code&gt;.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;os&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;pandas.io.parsers&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;read_csv&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;sklearn.utils&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shuffle&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;validation_split&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;load_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;validation_split&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create path to csv file&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;cwd&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;getcwd&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;csv_path&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;join&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;cwd&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;data/training.csv&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Load data from csv file into data frame, drop all rows that have missing values&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;data_frame&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;read_csv&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;csv_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data_frame&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Image&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;count&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;data_frame&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data_frame&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dropna&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data_frame&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Image&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;count&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Convert the rows of the image column from pixel values separated by spaces to numpy arrays&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;data_frame&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Image&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data_frame&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Image&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;apply&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;k&#34;&gt;lambda&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;fromstring&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sep&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34; &amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create numpy matrix from image column by stacking the rows vertically&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;vstack&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data_frame&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Image&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;values&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Normalize pixel values to (0, 1) range&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;255&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Convert to float32, which is the default for Keras&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;astype&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Reshape each row from one dimensional arrays to (height, width, num_channels) = (96, 96, 1)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reshape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;96&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;96&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Extract labels representing the coordinates of facial landmarks&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data_frame&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data_frame&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;columns&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]]&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;values&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Normalize coordinates to (0, 1) range&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;96&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;astype&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Shuffle data&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shuffle&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Split data into training set and validation set&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;split_index&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;validation_split&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;split_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;split_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;split_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;split_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;create-models&#34;&gt;Create Models&lt;/h3&gt;
&lt;p&gt;Landmark detection is a type of regression problem, since given an input image, the objective is to predict 30 numbers corresponding to the 15 two dimensional landmarks. To achieve this objective we will create a convolutional neural network (CNN). I experimented with different network architectures and chose the one that resulted in the lowest validation loss. The hyperparameters that I varied were: the number of convolutional layers, the filter size and number of filters of the convolutional layers, the number of fully connected layers and their size, the number of dropout layers (which ended up at zero). Many of the historically well-performing CNN models shrink the height and width of the layer inputs and grow the number of filters as the network gets deeper, which is also the case for the model that I ended up with. The height and width are unchanged by the convolutional layers as the &amp;ldquo;same&amp;rdquo; padding scheme is used. Thus the shrinking of height and width happens solely in the max-pooling layers. The final model is pretty simple and not that deep, which is great for training times.&lt;/p&gt;
&lt;p&gt;I also created a baseline model to compare with. The baseline model is a simple feedforward neural network with one hidden layer.&lt;/p&gt;
&lt;p&gt;Both models use an Adam optimizer and mean squared error as the loss function.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.models&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.layers&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Activation&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Flatten&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Conv2D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;MaxPool2D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Dropout&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_baseline_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Flatten&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;96&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;96&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;512&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;30&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;compile&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;adam&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;mean_squared_error&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_cnn_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Conv2D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;input_shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;96&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;96&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;MaxPool2D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pool_size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Conv2D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;MaxPool2D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pool_size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Conv2D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;MaxPool2D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pool_size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Flatten&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;512&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;30&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;compile&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;adam&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;mean_squared_error&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;train-models&#34;&gt;Train Models&lt;/h3&gt;
&lt;p&gt;I trained the models for 100 epochs using a batch size of 200. After training, the final loss values are saved in a text file, the histories of loss values are plotted together, and the model architectures and weights are saved. In Keras you can save a model in a single .h5 file that contains the architecture, weights, and optimizer state or you can save the architecture and weights separately. Since I won&amp;rsquo;t continue training the model after saving it I don&amp;rsquo;t need the optimizer state, and I opted to save the architecture and weights separately, which saves some disk space.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;fld_utils&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;model_factory&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;validation_split&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;100&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Run models&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;models&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model_names&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;models&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model_factory&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;create_baseline_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model_names&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;baseline&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;models&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model_factory&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;create_cnn_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model_names&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;cnn&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run_models&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;models&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_names&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;run_models&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;models&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_names&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_epochs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;results_file&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;open&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/results.txt&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;w&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;histories&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_name&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;zip&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;models&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_names&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Train model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;history&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;fit&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;200&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epochs&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_epochs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;validation_data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;histories&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Evaluate&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;final_train_loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;evaluate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;verbose&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;final_val_loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;evaluate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;verbose&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;results_file&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;write&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model_name&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;gt; final_train_loss: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;final_train_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;, final_val_loss: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;final_val_loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;se&#34;&gt;\n&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Save model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;save_weights&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;saved_models/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_name&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;_model_weights.h5&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;open&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;saved_models/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_name&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;_model_architecture.json&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;w&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;write&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;to_json&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plot_histories&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;histories&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_names&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;histories.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;results_file&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;close&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_histories&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;histories&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_names&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_name&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;zip&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;histories&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_names&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;loss&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;label&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model_name&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34; train loss&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;val_loss&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;label&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model_name&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34; val loss&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xlabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Epoch&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ylabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Loss (Mean squared error)&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;legend&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;yscale&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;log&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/facial_landmark_detection/histories.png&#34; alt=&#34;Histories&#34;&gt;&lt;/p&gt;
&lt;h3 id=&#34;predict-landmarks&#34;&gt;Predict Landmarks&lt;/h3&gt;
&lt;p&gt;Let us plot some landmark predictions on a randomly chosen image from the validation set to get a sense of the performance of the models. The first image is always chosen, but since the data is shuffled before partitioned, it is a randomly chosen image. The landmarks are extracted from network output by grouping the 30 values into 15 two-dimensional points. Note that the network predictions needs to be multiplied by image size since the values we have trained with were normalized.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;fld_utils&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;validation_split&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;img_size_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img_size_y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Plot correct landmarks&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;extract_landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;save_img_with_landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;data_visual.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gray_scale&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Baseline model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;baseline&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;predict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;extract_landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img_size_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img_size_y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;save_img_with_landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;baseline_prediction.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gray_scale&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# CNN model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;cnn&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;predict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;extract_landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img_size_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img_size_y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;save_img_with_landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;cnn_prediction.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gray_scale&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;plt&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.models&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;load_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_from_json&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;load_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;open&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;saved_models/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_name&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;_model_architecture.json&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;r&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_from_json&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;read&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_weights&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;saved_models/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model_name&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;_model_weights.h5&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;extract_landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img_size_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img_size_y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;landmark_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;landmark_y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img_size_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;+&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img_size_y&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;landmark_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;landmark_y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;save_img_with_landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gray_scale&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;gray_scale&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imshow&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;squeeze&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_cmap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;gray&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imshow&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;squeeze&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;landmark&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;landmark&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;landmark&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;go&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The correct landmarks:&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/facial_landmark_detection/data_visual.png&#34; alt=&#34;Correct landmarks&#34;&gt;&lt;/p&gt;
&lt;p&gt;The baseline prediction:&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/facial_landmark_detection/baseline_prediction.png&#34; alt=&#34;Baseline prediction&#34;&gt;&lt;/p&gt;
&lt;p&gt;The cnn prediction:&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/facial_landmark_detection/cnn_prediction.png&#34; alt=&#34;CNN prediction&#34;&gt;&lt;/p&gt;
&lt;p&gt;We can see that the baseline model performs horribly but the cnn model performs very well on the chosen image.&lt;/p&gt;
&lt;h3 id=&#34;sunglasses-filter&#34;&gt;Sunglasses Filter&lt;/h3&gt;
&lt;p&gt;Now we are finally ready to use our model in an application. I have created a simple script that applies a sunglasses filter to a face image. The predicted landmarks are used to scale, position, and rotate the sunglasses image. Since we have trained on normalized grayscale images of size (96, 96), we need to load the input image in grayscale mode, resize, and normalize it. To apply the filter we paste the processed sunglasses image on top of the original version of the input image that hasn&amp;rsquo;t had any preprocessing done to it. I normally use the library cv2 for image processing, but in this script PIL is also used. I used PIL because it has convenient functions for rotating images and pasting images on top of each other.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;cv2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;PIL&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Image&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;plt&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;fld_utils&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Load original image&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;face_img_path&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;input/picard.png&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;orig_img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imread&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;face_img_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;orig_img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;cvtColor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;orig_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;COLOR_BGR2RGB&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;orig_size_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;orig_size_y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;orig_img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;orig_img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Prepare input image&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imread&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;face_img_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;IMREAD_GRAYSCALE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;resize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dsize&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;96&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;96&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;interpolation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;INTER_AREA&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;255&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;astype&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Predict landmarks&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;cnn&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;predict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;extract_landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;orig_size_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;orig_size_y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Save original image with landmarks on top&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;fld_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;save_img_with_landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;orig_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;test_img_prediction.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Extract x and y values from landmarks of interest&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;left_eye_center_x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;left_eye_center_y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;right_eye_center_x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;right_eye_center_y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;left_eye_outer_x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;right_eye_outer_x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;landmarks&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;][&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Load images using PIL&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# PIL has better functions for rotating and pasting compared to cv2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;face_img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Image&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;open&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;face_img_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;sunglasses_img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Image&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;open&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;input/sunglasses.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Resize sunglasses&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;sunglasses_width&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;left_eye_outer_x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;right_eye_outer_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;1.4&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;sunglasses_height&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sunglasses_img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sunglasses_width&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sunglasses_img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;sunglasses_resized&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sunglasses_img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;resize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sunglasses_width&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sunglasses_height&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Rotate sunglasses&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;eye_angle_radians&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;arctan&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;right_eye_center_y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;left_eye_center_y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;left_eye_center_x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;right_eye_center_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;sunglasses_rotated&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sunglasses_resized&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;rotate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;degrees&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;eye_angle_radians&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;expand&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;resample&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Image&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;BICUBIC&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Compute positions such that the center of the sunglasses is&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# positioned at the center point between the eyes&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;x_offset&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sunglasses_width&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;y_offset&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sunglasses_height&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;pos_x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;left_eye_center_x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;right_eye_center_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x_offset&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;pos_y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;left_eye_center_y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;right_eye_center_y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;y_offset&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Paste sunglasses on face image&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;face_img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;paste&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sunglasses_rotated&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pos_x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pos_y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sunglasses_rotated&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;face_img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;save&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/test_img_sunglasses.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/facial_landmark_detection/test_img_prediction.png&#34; alt=&#34;Test image prediction&#34;&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/facial_landmark_detection/test_img_sunglasses.png&#34; alt=&#34;Test image sunglasses&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;This project was very fun and I&amp;rsquo;m pleased how it turned out. I was impressed with the performance of the model given its simplicity. I did not spend a lot of time on hyperparameter tuning, and I might revisit this step if I use facial landmark detection in a project that needs a more accurate model. Besides tuning hyperparameters, performance could also be improved by training on more data. This could be done by finding a bigger dataset, labeling images on your own, and/or data augmentation. If data augmentation is done, one has to be careful in making sure that the images are still correctly labeled. For example, if an image is flipped horizontally the landmarks also has to be flipped in the same way. However, the order of the landmarks must be consistent between images. If a landmark that represents the center of the left eye is flipped, it now represents the center of the right eye. Thus for a flipped image, the order of landmarks that represent left and right versions of a facial feature has to be swapped to maintain consistency.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Regression using Keras</title>
      <link>https://cfml.se/regression-using-keras/</link>
      <pubDate>Sat, 29 Sep 2018 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/regression-using-keras/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_using_keras/results_nn.png&#34; alt=&#34;NN Regression&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;I have started experimenting with Keras, which is a high-level neural networks API. It provides an interface on top of TensorFlow, CNTK, or Theano. So far have been using Keras with TensorFlow as a backend. Keras makes prototyping and experimentation faster and easier, while removing some of the flexibility that TensorFlow has. If you use the TensorFlow backend you can combine the two when the need arises for functionality that Keras does not provide on its own.&lt;/p&gt;
&lt;p&gt;In this blog post I am going to show you how to implement simple regression in Keras. We are implementing linear regression as a baseline model, and a small feedforward neural network as a more complex model. For more information about regression, and how to implement it in TensorFlow, check out my previous post &lt;a href=&#34;https://cfml.se/regression-techniques/&#34;&gt;Regression Techniques&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/regression_using_keras&#34;&gt;https://github.com/CarlFredriksson/regression_using_keras&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;implementation&#34;&gt;Implementation&lt;/h2&gt;
&lt;h3 id=&#34;import-modules&#34;&gt;Import Modules&lt;/h3&gt;
&lt;p&gt;We will need to import the following modules:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;plt&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.models&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.layers&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;keras.optimizers&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;SGD&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;generate-data&#34;&gt;Generate Data&lt;/h3&gt;
&lt;p&gt;Since this project is only a proof of concept, I decided to use randomly generated datasets. The generated data is of the form &lt;code&gt;$y=x^2-2$&lt;/code&gt; for &lt;code&gt;$0 \leq x \leq 3$&lt;/code&gt; with some normally distributed random noise added to simulate real world data.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generate_random_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;linspace&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;200&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;**&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;noise&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;normal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;noise&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;astype&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Let us generate datasets for training and validation.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;generate_random_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;generate_random_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;To visualize the datasets we can plot them using Matplotlib.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;scatter&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;color&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;blue&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;grid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xlabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;x&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ylabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;y&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;plot_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;data_train.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;plot_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;data_val.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_using_keras/data_train.png&#34; alt=&#34;Train data&#34;&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_using_keras/data_val.png&#34; alt=&#34;Val data&#34;&gt;&lt;/p&gt;
&lt;h3 id=&#34;create-baseline-model&#34;&gt;Create Baseline Model&lt;/h3&gt;
&lt;p&gt;To implement simple linear regression we can use a neural network without hidden layers. In Keras we use a single dense layer for this. A dense layer is a normal fully connected layer. Note that the first (and only layer in this case) of a sequential Keras model needs to specify the input shape. To finish creating a model we need to compile it, while specifying what optimizer we want and what loss function to use.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_baseline_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;input_shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;compile&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SGD&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;lr&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.001&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;mean_squared_error&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;create-more-complex-model&#34;&gt;Create More Complex Model&lt;/h3&gt;
&lt;p&gt;The more complex model will contain three hidden layers of ten neurons each. Do not forget to add an activation function to the hidden layers.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_nn_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;input_shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;relu&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;add&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;compile&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SGD&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;lr&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.001&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;mean_squared_error&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;train-models&#34;&gt;Train Models&lt;/h3&gt;
&lt;p&gt;Training a model in Keras is very simple with the &lt;code&gt;model.fit&lt;/code&gt; function. It is one of the nice high level functions that allows us to get up and runnning very quickly. Since the training set I generated is very small, I decided to use the whole set each batch.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;fit&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epochs&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10000&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;validation_data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The &lt;code&gt;model.fit&lt;/code&gt; function returns a history of the training and validation losses. To evaluate a model it is often a good idea to plot the history.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;loss&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;label&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Train loss&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;val_loss&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;label&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Val loss&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xlabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Epoch&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ylabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Loss (Mean squared error)&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;legend&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;For the baseline model:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;plot_history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;history_baseline.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_using_keras/history_baseline.png&#34; alt=&#34;History baseline&#34;&gt;&lt;/p&gt;
&lt;p&gt;For the more complex model:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;plot_history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;history&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;history_nn.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_using_keras/history_nn.png&#34; alt=&#34;History nn&#34;&gt;&lt;/p&gt;
&lt;p&gt;We can also evaluate models using the &lt;code&gt;model.evaluate&lt;/code&gt; function. This returns the final loss for the specified dataset.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;final_train_loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;evaluate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;verbose&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;final_val_loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;evaluate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;verbose&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;For the run that resulted in the plots above, the final training and validation losses were 1.566 and 1.621 for the baseline model, 1.080 and 1.184 for the more complex model.&lt;/p&gt;
&lt;h3 id=&#34;using-the-trained-models&#34;&gt;Using the Trained Models&lt;/h3&gt;
&lt;p&gt;In order to output predictions for a given dataset we can use the &lt;code&gt;model.predict&lt;/code&gt; function.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;Y_predict&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;predict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The predictions can be plotted on top of the data using the following function:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_results&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_predict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;scatter&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;color&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;blue&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_predict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;color&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;red&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;grid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xlabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;x&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ylabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;y&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;For the baseline model:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;plot_results&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_predict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;results_baseline.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_using_keras/results_baseline.png&#34; alt=&#34;Results baseline&#34;&gt;&lt;/p&gt;
&lt;p&gt;For the more complex model:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;plot_results&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_val&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_predict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;results_nn.png&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_using_keras/results_nn.png&#34; alt=&#34;Results baseline&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;Keras is a powerful library that lets developers iterate quickly. For this simple project there was no need for deep networks, but Keras was developed with deep learning in mind, and is well suited for creating deeper and more complex models.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Digit Recognition</title>
      <link>https://cfml.se/digit-recognition/</link>
      <pubDate>Tue, 18 Sep 2018 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/digit-recognition/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/digit_recognition/top_img.png&#34; alt=&#34;Digit recognition&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;Digit recognition is a form of multi-class classification where the inputs are images of hand written digits, and the outputs are the corresponding numbers (0-9). Classifying handwritten text or numbers is important for many real world scenarios. For example, a postal service can scan postal codes on envelopes in order to automate grouping of envelopes that should be sent to the same city. Digit recognition is one of the simplest versions of classifying handwriting, and is a good first project for getting into the field of image classification.&lt;/p&gt;
&lt;p&gt;I have implemented a small convolutional neural network (CNN) in TensorFlow, that achieves a decent performance on the MNIST test set (≈1% classification error). MNIST is a database containing gray scale images of handwritten digits. The database contains a training set of 60000 examples and a test set of 10000 examples. Read more on &lt;a href=&#34;http://yann.lecun.com/exdb/mnist/&#34;&gt;http://yann.lecun.com/exdb/mnist/&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/digit_recognition&#34;&gt;https://github.com/CarlFredriksson/digit_recognition&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;implementation&#34;&gt;Implementation&lt;/h2&gt;
&lt;h3 id=&#34;loading-the-mnist-data&#34;&gt;Loading the MNIST Data&lt;/h3&gt;
&lt;p&gt;You could download the data from &lt;a href=&#34;http://yann.lecun.com/exdb/mnist/&#34;&gt;http://yann.lecun.com/exdb/mnist/&lt;/a&gt;, but a convenient alternative is to import from &lt;code&gt;keras.datasets&lt;/code&gt;.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow.keras.datasets&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mnist&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;load_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mnist&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;In order to visualize the data I plotted the first four training examples.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;visualize_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;221&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imshow&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_cmap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;gray&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;title&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;y: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;222&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imshow&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_cmap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;gray&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;title&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;y: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;223&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imshow&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_cmap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;gray&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;title&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;y: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subplot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;224&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imshow&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_cmap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;gray&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;title&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;y: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/digit_recognition/data_visual.png&#34; alt=&#34;Visualize data&#34;&gt;&lt;/p&gt;
&lt;h3 id=&#34;preprocessing&#34;&gt;Preprocessing&lt;/h3&gt;
&lt;p&gt;A standard prepropcessing step is to normalize the inputs, thus we divide all pixel values by 255 to put them in the (0-1) range. In the output layer we are going to use the softmax activation, thus we need to convert the outputs from (0-9) to one hot vectors. A one hot vector is a vector with a single 1 and 0 in the other positions. For example, the one hot vectors for 0, 3, and 9 are $[1,0,0,0,0,0,0,0,0,0]$, $[0,0,0,1,0,0,0,0,0,0]$, and $[0,0,0,0,0,0,0,0,0,1]$ respectively. Finally, since we are building a CNN, we also want to add a channels dimension to the input. We only need one channel since the images are in grayscale.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;preprocess_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Normalize image pixel values from 0-255 to 0-1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;255&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;255&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Change y values from 0-9 to one hot vectors&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;convert_to_one_hot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;convert_to_one_hot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Add channels dimension&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;convert_to_one_hot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_onehot&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;max&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_onehot&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;arange&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_onehot&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;To significantly speed up the training process, we are going to divide the training set into mini batches. Instead of doing one step of gradient descent after processing the whole training set, we are going to do one gradient descent step after each mini batch. I used a mini batch size of 200, which gives us 300 mini batches.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;random_mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;c1&#34;&gt;# Number of training examples&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Shuffle training examples&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;permutation&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;list&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;permutation&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;m&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_shuffled&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;permutation&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_shuffled&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;permutation&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Partition into mini-batches&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_complete_mini_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;math&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;floor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_complete_mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;X_mini_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_shuffled&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batch_size&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Y_mini_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_shuffled&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batch_size&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batch_size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;mini_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_mini_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_mini_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mini_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Handling the case that the last mini-batch &amp;lt; mini_batch_size&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;%&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batch_size&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;!=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;X_mini_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_shuffled&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_complete_mini_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batch_size&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Y_mini_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_shuffled&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_complete_mini_batches&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batch_size&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;m&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;mini_batch&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_mini_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_mini_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mini_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;creating-the-model&#34;&gt;Creating the Model&lt;/h3&gt;
&lt;p&gt;Now it is time to create the CNN model. To have a short training time I chose a small model with a single convolutional layer. The simple model yields pretty decent results (≈1% classification error), but deeper models can perform even better. A list of the best models can be found at &lt;a href=&#34;http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html&#34;&gt;http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;The model can be summarized as: convolutional&amp;ndash;&amp;gt;max_pool&amp;ndash;&amp;gt;dropout&amp;ndash;&amp;gt;flatten&amp;ndash;&amp;gt;dense&amp;ndash;&amp;gt;dense&amp;ndash;&amp;gt;softmax. The dropout layer is used to reduce overfitting. The dense layers are fully connected layers. Note the placeholder &lt;code&gt;training_flag&lt;/code&gt; which will decide when the dropout should be applied. We only want dropo`$ut applied when training, thus the placeholder is set to false as a default.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;height&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;width&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;channels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_classes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reset_default_graph&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;height&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;width&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;channels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;X&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_classes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Y&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;training_flag&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder_with_default&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;conv1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2d&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;filters&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;kernel_size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;same&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;pool1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;max_pooling2d&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pool_size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;valid&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Dropout does not apply by default, training=True is needed to make the layer do anything&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# We only want dropout applied during training&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;dropout&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dropout&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pool1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;rate&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;training_flag&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;flatten&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;flatten&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dropout&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;dense1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;flatten&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;128&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dense1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_classes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;activation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;softmax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Y_hat&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Compute cost&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;compute_cost&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training_flag&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;To compute the cost I used the standard softmax cost function. Let $n_c$ be the number of classes ($n_c=10$ in our case). For a single training example with correct one hot vector $y$ and softmax output $\hat{y}$, the cost is:&lt;/p&gt;
&lt;p&gt;$$
J(y,\hat{y}) = -\sum_{i=1}^{n_c} y_i \log{\hat{y}_i}
$$&lt;/p&gt;
&lt;p&gt;To compute the full cost we average over all training examples in the current mini batch. Note that a very small value is added to logarithm inputs, to avoid taking the log of 0.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;compute_cost&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Add small value epsilon to tf.log() calls to avoid taking the log of 0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;1e-10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;J&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;training-the-model&#34;&gt;Training the Model&lt;/h3&gt;
&lt;p&gt;To train the model I used an Adam optimizer with a learning rate of 0.001. I trained for 10 epochs. After training, the graph and variables are saved as a TensorFlow SavedModel. Classification accuracy on the training and test sets are computed by dividing the correctly classified examples by the size of the data set. The model achieves a test set accuracy of about 99%, or in other words 1% error. Error is often used instead of accuracy when comparing models with very high accuracy.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;run_model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training_flag&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create train op&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;AdamOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Start session&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Session&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Initialize variables&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;global_variables_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Training loop&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_EPOCHS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_mini_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_mini_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;mini_batches&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_mini_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_mini_batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;training_flag&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;epoch: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;, J_train: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Final costs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;J_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Compute training accuracy&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Y_pred&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;accuracy_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;compute_accuracy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Compute test accuracy&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Y_pred&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;accuracy_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;compute_accuracy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Save model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;saved_model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;simple_save&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;saved_model&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;inputs&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;X&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Y&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;},&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;outputs&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Y_hat&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;accuracy_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;accuracy_test&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;compute_accuracy&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_pred&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;argmax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_real&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;argmax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;num_correct&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_pred&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_real&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;accuracy&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_correct&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_real&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;accuracy&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;testing-on-my-own-handwritten-image&#34;&gt;Testing on My Own Handwritten Image&lt;/h3&gt;
&lt;p&gt;I made a simple black and white image in paint.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/digit_recognition/2_handwritten.png&#34; alt=&#34;Handwritten image&#34;&gt;&lt;/p&gt;
&lt;p&gt;The image needs to be loaded and preprocessed. The preprocessing involves inverting the colors, since all of the images in MNIST have a black background. The image also needs to be resized to (28,28) which is the size of the MNIST images. As before we divide by 255 to normalize input values, and add a channel dimension. This time we also add a dimension to the start, since the model expects sets of input images.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;load_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imread&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;IMREAD_GRAYSCALE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;preprocess_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;invert_colors&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;invert_colors&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;bitwise_not&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;resize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dsize&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;interpolation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;INTER_CUBIC&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;255&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;To test the trained model on my own image I loaded the graph and variables from the SavedModel in another session. To do a prediction we need to get the value of the output &lt;code&gt;Y_hat/Softmax:0&lt;/code&gt;, with our handwritten image as input &lt;code&gt;X:0&lt;/code&gt;. The value of &lt;code&gt;Y:0&lt;/code&gt; is irrelevant at this point, we just need to provide something that fits the expected dimensions. The index &lt;code&gt;:0&lt;/code&gt; is added to names by TensorFlow to avoid duplicates.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Start session&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Session&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Load model&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;saved_model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;loader&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;saved_model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tag_constants&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SERVING&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;saved_model&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Predict digit for test image&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Y_hat/Softmax:0&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;X:0&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;Y:0&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;argmax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Predicted digit: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y_pred&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;As expected the output is &lt;code&gt;Predicted digit: 2&lt;/code&gt;.&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;I am satisfied about how this project went, and it impressed me that such a simple model can get such decent results. In fact, a simple feedforward neural network (FNN) can also achieve pretty respectable results on MNIST. The CNN model I showed is slightly more complex than the small FNN I tried, but I chose the CNN since the classification error was lower (the FNN got about 2%) and the training time was still very fast.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/digit_recognition&#34;&gt;https://github.com/CarlFredriksson/digit_recognition&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;Thank you for reading, and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Binary Classification</title>
      <link>https://cfml.se/binary-classification/</link>
      <pubDate>Thu, 06 Sep 2018 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/binary-classification/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/binary_classification/nn_db_non_linear_train.png&#34; alt=&#34;Binary Classification 2D&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;Classification is an important Machine Learning task where an algorithm learns to associate input features to a finite number of classes. For example, given pixels from an animal image as input, output what animal is in that image. Binary classification is a special case where the number of classes is two, for example true or false, good or bad, cat or no cat etc.&lt;/p&gt;
&lt;p&gt;I have implemented two binary classification models in TensorFlow, one using simple logistic regression and another using a shallow feedforward neural network (FNN). Logistic regression is a linear classifier, which means that it can only divide input data using a straight line (or plane/hyperplane depending on the dimensionality). Using a FNN with a logistic node as output allows us to divide data using more complex functions. I decided to only use two dimensional data for visualization purposes, but the code can easily be adapted to data with higher dimensionality.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/binary_classification&#34;&gt;https://github.com/CarlFredriksson/binary_classification&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;implementation&#34;&gt;Implementation&lt;/h2&gt;
&lt;h3 id=&#34;import-modules&#34;&gt;Import Modules&lt;/h3&gt;
&lt;p&gt;We will need to import the following modules:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;os&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;mpl&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;plt&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;generate-data&#34;&gt;Generate Data&lt;/h3&gt;
&lt;p&gt;Since this project is only a proof of concept, I decided to use generated training data instead of looking for some real world data. I generated linear data that can be divided using a straight line, and set one side of the line to be class 0 and the other class 1. Some data points had their classes flipped randomly in order to simulate noise that would exist in the real world. Classes were flipped with higher probability for the data points that were close to the boundary line.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generate_linear_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_data_points&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;rand&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_data_points&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Set classes with the line x_2 = 0.5 as a boundary&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_data_points&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Flip the class of the points randomly, with higher probability if the point is close to the boundary&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;distance&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;abs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;flip_prob&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;max&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.3&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;distance&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;rand&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;flip_prob&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;%&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;I also implemented a function for generating non-linear data in the form of a circle. This time I did not bother with adding more noise, since logistic regression will be completely useless on this data set anyway, and it will be sufficient to show the difference between the classification models.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generate_non_linear_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_data_points&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;rand&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_data_points&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Set classes with a ring centered at (0.5, 0.5) as a boundary&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;linalg&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;norm&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_data_points&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;Let us generate some training data that will be used to train the models.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/binary_classification/data_linear_train.png&#34; alt=&#34;Linear training data&#34;&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/binary_classification/data_non_linear_train.png&#34; alt=&#34;Non-linear training data&#34;&gt;&lt;/p&gt;
&lt;p&gt;Let us also generate some test data that will be used to evaluate the models.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/binary_classification/data_linear_test.png&#34; alt=&#34;Linear test data&#34;&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/binary_classification/data_non_linear_test.png&#34; alt=&#34;Non-linear test data&#34;&gt;&lt;/p&gt;
&lt;p&gt;I used the following function to plot the data sets:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;colors&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;scatter&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;c&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;colors&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mpl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;colors&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ListedColormap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;red&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;blue&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;edgecolors&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;black&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xlabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;x_1&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ylabel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;x_2&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;logistic-regression&#34;&gt;Logistic Regression&lt;/h3&gt;
&lt;p&gt;The first model uses simple logistic regression which outputs $\hat{y} = \sigma(x^T w + b)$, where $\sigma$ is the sigmoid function, $x$ is an input vector, $w$ is a weight vector, and $b$ is a bias vector. We predict class 0 if $\hat{y} &amp;lt;= 0.5$ and class 1 if $\hat{y} &amp;gt; 0.5$. The objective is to find the best decision boundary. A decision boundary separates points in input space where on one side of the boundary one class is predicted, and on the other side another class is predicted. We want a cost function that leads to a decision boundary that predicts the correct class for as many training examples as possible. Let $y \in {0,1}$ be the correct class for a training example $x$, and $\hat{y} = \sigma(x^T w + b)$. The logistic cost function for one training example is:&lt;/p&gt;
&lt;p&gt;$$
L(y, \hat{y}) = -(y \log{\hat{y}} + (1 - y) \log{(1 - \hat{y})})
$$&lt;/p&gt;
&lt;p&gt;For a training set of many examples we want to compute the average cost. Let $X_{train}$ be the training set with $|X_{train}|$ number of training examples, and $Y_{train}$ the correct classes for those examples. We have the complete logistic cost function:&lt;/p&gt;
&lt;p&gt;$$
J(X_{train}, Y_{train}, \hat{y}) = \frac{1}{|X_{train}|} \sum_{x \in X_{train}} L(Y_{train}(x), \hat{y}(x))
$$&lt;/p&gt;
&lt;p&gt;The reason for using this cost function instead of the squared error cost, is that it is a convex function which leads to more effective gradient descent.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;logistic_regression_2D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_epochs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;db_plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reset_default_graph&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create parameters&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;W&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;W&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;contrib&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xavier_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;b&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;b&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Forward propagation&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;X&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Y&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sigmoid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;matmul&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;W&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;b&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Compute cost, add small value epsilon to tf.log() calls to avoid taking the log of 0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;1e-10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create train op&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GradientDescentOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Start session&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Session&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Initialize variables&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;global_variables_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Training loop&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_epochs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;%&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1000&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;i: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;, J_train: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Evaluate&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;J_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Plot decision boundary&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;predict_func&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;lambda&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_grid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_grid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;bc_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot_decision_boundary&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;predict_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;db_plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_test&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;plot_decision_boundary&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;predict_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;interval&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;arange&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;1.1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.001&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;meshgrid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;interval&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;interval&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_grid&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;c_&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ravel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ravel&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_grid&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;predict_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_grid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;predictions_grid&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;round&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_grid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;predictions_grid&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;predictions_grid&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reshape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;contourf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;predictions_grid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mpl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;colors&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ListedColormap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;#ff6868&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;#6875ff&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;scatter&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;c&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[:,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cmap&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mpl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;colors&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ListedColormap&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;red&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;blue&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;edgecolors&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;black&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;savefig&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;output/&amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bbox_inches&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;tight&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;plt&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clf&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;As one might expect, this simple model yielded a good result when applied to our generated linear data. The final cost after 20000 epochs and a learning rate of 0.1, was 0.2094 on the training set and 0.2434 on the test set.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/binary_classification/lr_db_linear_train.png&#34; alt=&#34;Logistic regression linear training set&#34;&gt;&lt;/p&gt;
&lt;p&gt;As we have already discussed, logistic regression is a linear classifier and should have horrible results on a non-linear data set. This was the case for our generated non-linear data. The algorithm could not find a good decision boundary, and instead predicted the same class for the whole domain. The final cost after 20000 epochs and a learning rate of 0.1, was 0.5892 on the training set and 0.5800 on the test set.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/binary_classification/lr_db_non_linear_train.png&#34; alt=&#34;Logistic regression non-linear training set&#34;&gt;&lt;/p&gt;
&lt;h3 id=&#34;neural-network-classification&#34;&gt;Neural Network Classification&lt;/h3&gt;
&lt;p&gt;The second model uses a small FNN. The network has two hidden layers with ten neurons each. ReLU is used as an activation function for all layers except the output layer, which uses the sigmoid function. Note that if the hidden layers are removed, we are left with simple logistic regression. I used the same cost function as in the first model.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;nn_binary_classification_2D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_epochs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;db_plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reset_default_graph&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create parameters&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;W_1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;W_1&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;contrib&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xavier_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;b_1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;b_1&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;W_2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;W_2&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;contrib&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xavier_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;b_2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;b_2&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;W_3&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;W_3&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;contrib&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xavier_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;b_3&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;b_3&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Forward propagation&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;X&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Y&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;matmul&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;W_1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;b_1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;matmul&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;W_2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;b_2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sigmoid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;matmul&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;W_3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;b_3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Compute cost, add small value epsilon to tf.log() calls to avoid taking the log of 0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;1e-10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;log&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epsilon&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create train op&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GradientDescentOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Start session&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Session&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Initialize variables&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;global_variables_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Training loop&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_epochs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;%&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1000&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;i: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;, J_train: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Evaluate&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;J_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Plot decision boundary&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;predict_func&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;lambda&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_grid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_hat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_grid&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;bc_utils&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;plot_decision_boundary&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;predict_func&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;db_plot_name&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_test&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The FNN model also yielded a good result when applied to our generated linear data. The final cost after 20000 epochs and a learning rate of 0.1, was 0.1979 on the training set and 0.2452 on the test set. It did overfit the training set more than the first model as the cost difference between the training set and the test set was greater, and we can see that the decision boundary was not linear. Some overfitting is to be expected since the FNN model is capable of more complex functions, and we did not use any anti-overfitting techniques such as regularization.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/binary_classification/nn_db_linear_train.png&#34; alt=&#34;NN classification linear training set&#34;&gt;&lt;/p&gt;
&lt;p&gt;The big difference in performance comes when we switch to the non-linear data set. The more flexible FNN model can learn to find an approximately circular decision boundary, and fit the training set very well. The final cost after 20000 epochs and a learning rate of 0.1, was 0.0048 on the training set and 0.1214 on the test set. This time both costs where smaller but the overfitting greater, which can be explained by the relatively small amount of noise in the generated non-linear data, thus allowing a better fit to the training set.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/binary_classification/nn_db_non_linear_train.png&#34; alt=&#34;NN classification non-linear training set&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;Logistic regression can be a good choice when you are sure that a linear decision boundary will be sufficient. Otherwise you should probably use a neural network or a similarly flexible model. You might have to put in some work in order to find a good network architecture and to avoid overfitting, but the resulting model should have better results on most non-linear data sets.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/binary_classification&#34;&gt;https://github.com/CarlFredriksson/binary_classification&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;Thank you for reading, and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Regression Techniques</title>
      <link>https://cfml.se/regression-techniques/</link>
      <pubDate>Thu, 30 Aug 2018 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/regression-techniques/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/linear_regression_train_1.png&#34; alt=&#34;Linear Regression 1D&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;Regression is an important part of Machine Learning, where one tries to find an underlying function that is a good fit to some noisy data. As an example, you might want to predict how much you could sell a house for given data from other house sales. The data would contain input/output pairs of house features (such as size, number of rooms, location, etc.) and corresponding price.&lt;/p&gt;
&lt;p&gt;I have implemented two models in TensorFlow, one using simple linear regression and another using a shallow feedforward neural network (FNN). For visualization purposes, I decided to only use one dimensional data, but the code can easily be adapted to data with multiple dimensions.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/regression_techniques&#34;&gt;https://github.com/CarlFredriksson/regression_techniques&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;implementation&#34;&gt;Implementation&lt;/h2&gt;
&lt;h3 id=&#34;import-modules&#34;&gt;Import Modules&lt;/h3&gt;
&lt;p&gt;We will need to import the following modules:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;os&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;plt&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;generate-data&#34;&gt;Generate Data&lt;/h3&gt;
&lt;p&gt;Since this project is only a proof of concept, I decided to use generated training data instead of looking for some real world data.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;generate_random_data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;b&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;non_linear&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;noise_stdev&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;Generate data with random noise.&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;linspace&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;50&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;b&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;non_linear&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;**&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;b&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;noise&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;normal&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;noise_stdev&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;noise&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;astype&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The function can generate linear data of the form $y=ax+b$ and non-linear data of the form $y=ax^2+b$. I used $a=1$ and $b=2$ for all data sets.&lt;/p&gt;
&lt;p&gt;Let us generate some training data that will be used to train the models.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/data_linear_train.png&#34; alt=&#34;Linear training data&#34;&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/data_non_linear_train.png&#34; alt=&#34;Non-linear training data&#34;&gt;&lt;/p&gt;
&lt;p&gt;Let us also generate some test data that will be used to evaluate the models.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/data_linear_test.png&#34; alt=&#34;Linear test data&#34;&gt;&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/data_non_linear_test.png&#34; alt=&#34;Non-linear test data&#34;&gt;&lt;/p&gt;
&lt;h3 id=&#34;linear-regression&#34;&gt;Linear Regression&lt;/h3&gt;
&lt;p&gt;The first model I implemented used simple linear regression. This technique assumes that the underlying function is of the form $y=ax+b$, and the objective is to find the values for $a$ and $b$ that best fit the data. We need to define a cost function that makes us achieve this objective, and a very common one is squared error cost. Let us define $m$ as the number of training examples, $X_{train}$ as the input features, and $Y_{train}$ as the correct outputs for those features. Both $X_{train}$ and $Y_{train}$ have the dimensions $(m,1)$. Let us also define $Y_{estimate}$ as the values we get using our estimated $a$ and $b$ values ($a_{estimate}$ and $b_{estimate}$). The cost is defined as:&lt;/p&gt;
&lt;p&gt;$$
J(X_{train}, Y_{train}, Y_{estimate}) = \frac{1}{m} \sum_{x \in X_{train}} (Y_{train}(x) - Y_{estimate}(x))^2
$$&lt;/p&gt;
&lt;p&gt;or in other words:&lt;/p&gt;
&lt;p&gt;$$
J(X_{train}, a, b, a_{estimate}, b_{estimate}) = \frac{1}{m} \sum_{x \in X_{train}} ((ax+b) - (a_{estimate}x+b_{estimate}))^2
$$&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;linear_regression_1D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_iterations&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;y_est&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create parameters&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;a_estimate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;b_estimate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Compute cost&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_placeholder&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;X_placeholder&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_placeholder&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Y_placeholder&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_estimate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a_estimate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_placeholder&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;b_estimate&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_placeholder&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_estimate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;**&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create train op&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GradientDescentOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Start session&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Session&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Initialize variables&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;global_variables_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Training loop&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_iterations&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Create estimated Y values&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;y_est&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_estimate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;J_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;y_est&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_test&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;As one might expect, this simple model yielded a good result when applied to our generated linear data. It fit the training set pretty well with a final cost of 0.9422.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/linear_regression_train_1.png&#34; alt=&#34;Linear regression train 1&#34;&gt;&lt;/p&gt;
&lt;p&gt;More importantly, it also had good results on the test set with a final cost of 1.0851.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/linear_regression_test_1.png&#34; alt=&#34;Linear regression test 1&#34;&gt;&lt;/p&gt;
&lt;p&gt;However, when the model is used on a non-linear data set, it obviously fails miserably. On the training set the final cost was 53.182.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/linear_regression_train_2.png&#34; alt=&#34;Linear regression train 2&#34;&gt;&lt;/p&gt;
&lt;p&gt;On the test set the final cost was 62.886.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/linear_regression_test_2.png&#34; alt=&#34;Linear regression train 2&#34;&gt;&lt;/p&gt;
&lt;h3 id=&#34;neural-network-regression&#34;&gt;Neural Network Regression&lt;/h3&gt;
&lt;p&gt;The second model I implemented uses a small FNN. The network has two hidden layers with ten neurons each. ReLU is used as an activation function for all layers except the output layer, which has a single linear neuron. Note that if the hidden layers are removed, we are left with simple linear regression. I used the same cost function as in the first model.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;nn_regression_1D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_iterations&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reset_default_graph&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;y_est&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create parameters&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;W_1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;W_1&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;contrib&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xavier_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;b_1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;b_1&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;W_2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;W_2&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;contrib&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xavier_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;b_2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;b_2&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;W_3&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;W_3&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;contrib&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;xavier_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;b_3&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;b_3&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;initializer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zeros_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Forward propagation&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X_placeholder&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;X_placeholder&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_placeholder&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;Y_placeholder&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;matmul&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;W_1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;b_1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;matmul&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;W_2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;b_2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;X&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Y_estimate&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;matmul&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;W_3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;b_3&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Compute cost&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_placeholder&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_estimate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;**&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create training operation&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GradientDescentOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_rate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Start session&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Session&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Initialize variables&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;global_variables_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Training loop&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_iterations&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Create estimated Y values&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;y_est&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Y_estimate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;J_test&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;X_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;X_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_placeholder&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Y_test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;y_est&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_test&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The FNN model also yielded a good result when applied to our generated linear data. It fit the training set even better than the first model with a final cost of 0.7937.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/nn_regression_train_1.png&#34; alt=&#34;NN regression train 1&#34;&gt;&lt;/p&gt;
&lt;p&gt;However, it had worse performance on the test set with a final cost of 1.1286, which is slightly higher than for the first model. This significant difference between performance on the training and test sets are due to overfitting. Overfitting can be combatted by introducing regularization/dropout, getting more data, and/or reducing the size of the network.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/nn_regression_test_1.png&#34; alt=&#34;NN regression test 1&#34;&gt;&lt;/p&gt;
&lt;p&gt;The second model shows its strength when used on the non-linear data set. Since it does not assume the form of the underlying function, it can fit to non-linear data as well. The final cost on the training set was 1.0929.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/nn_regression_train_2.png&#34; alt=&#34;NN regression train 2&#34;&gt;&lt;/p&gt;
&lt;p&gt;It once again did some overfitting, and the final cost on the test set was 1.6068.&lt;/p&gt;
&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/regression_techniques/nn_regression_test_2.png&#34; alt=&#34;NN regression test 2&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;A simple linear regression model can be a good choice when you are absolutely sure that the underlying function is linear. Otherwise I would use a neural network or a similarly flexible model. You might have to put in more work to find a good network architecture and to avoid overfitting, but you do not have to worry about the possible non-linearity of your data set.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/regression_techniques&#34;&gt;https://github.com/CarlFredriksson/regression_techniques&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;Thank you for reading, and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
    
    <item>
      <title>Neural Style Transfer</title>
      <link>https://cfml.se/neural-style-transfer/</link>
      <pubDate>Mon, 20 Aug 2018 00:00:00 +0000</pubDate>
      
      <guid>https://cfml.se/neural-style-transfer/</guid>
      <description>&lt;p&gt;&lt;img src=&#34;https://cfml.se/images/neural_style_transfer/nst_examples.jpg&#34; alt=&#34;NST Examples&#34;&gt;&lt;/p&gt;
&lt;h2 id=&#34;introduction&#34;&gt;Introduction&lt;/h2&gt;
&lt;p&gt;Neural Style Transfer (NST) is one of the coolest techniques in deep learning. It is an algorithm that generates a new image starting from a content image (the cat in the image above) and a style image. The objective is for the generated image to contain the &amp;ldquo;content&amp;rdquo; of the content image, but have the same &amp;ldquo;style&amp;rdquo; as the style image. NST is a form of transfer learning which uses a pretrained convolutional neural network (CNN). Instead of training the weights of the network, we train the generated image, which is used as an input to network, until the objective is achieved sufficiently. The algorithm was created by &lt;a href=&#34;https://arxiv.org/abs/1508.06576&#34;&gt;Gatys et al. (2015)&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;I have implemented NST in TensorFlow using a pretrained VGG-19 network trained on the ImageNet dataset. To avoid having to search the Internet for pretrained weights and having to create the layers myself, I used a pretrained Keras model from keras.applications. I extracted the layers of the Keras model and used them to create a TensorFlow graph.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/neural_style_transfer&#34;&gt;https://github.com/CarlFredriksson/neural_style_transfer&lt;/a&gt;.&lt;/p&gt;
&lt;h2 id=&#34;implementation&#34;&gt;Implementation&lt;/h2&gt;
&lt;h3 id=&#34;import-modules&#34;&gt;Import modules&lt;/h3&gt;
&lt;p&gt;To get started we need to import some modules.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;os&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;cv2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;numpy&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;np&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tf&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow.keras&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;backend&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;K&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow.keras.applications&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;VGG19&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;tensorflow.keras.layers&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;MaxPooling2D&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;define-constants&#34;&gt;Define constants&lt;/h3&gt;
&lt;p&gt;Let us define some constants we will need.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;CONTENT_IMG_PATH&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;./input/cat.jpg&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;STYLE_IMG_PATH&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;./input/starry_night.jpg&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;GENERATED_IMG_PATH&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;./output/generated_img.jpg&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;IMG_SIZE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;400&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;300&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;NUM_COLOR_CHANNELS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;ALPHA&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;BETA&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;40&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;NOISE_RATIO&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.6&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;CONTENT_LAYER_INDEX&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;13&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;STYLE_LAYER_INDICES&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;7&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;12&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;17&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;STYLE_LAYER_COEFFICIENTS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;0.2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;NUM_ITERATIONS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;500&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;VGG_IMAGENET_MEANS&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;array&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;103.939&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;116.779&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mf&#34;&gt;123.68&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reshape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt; &lt;span class=&#34;c1&#34;&gt;# In blue-green-red order&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;LOG_GRAPH&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The purpose of some of these are self-apparent and the others will explained later on. Feel free to play around with different constant values and see how the results are affected.&lt;/p&gt;
&lt;h3 id=&#34;load-and-preprocess-images&#34;&gt;Load and preprocess images&lt;/h3&gt;
&lt;p&gt;We need to load the content and style images, and do some simple preprocessing before they can be used as input to the network. I used the library OpenCV (cv2) for image processing.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;load_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;color_means&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;Load image from path, preprocess it, and return the image.&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imread&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;resize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dsize&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;size&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;interpolation&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;INTER_CUBIC&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;astype&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;color_means&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;expand_dims&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;axis&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;content_img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;load_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;CONTENT_IMG_PATH&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;IMG_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;VGG_IMAGENET_MEANS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;style_img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;load_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;STYLE_IMG_PATH&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;IMG_SIZE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;VGG_IMAGENET_MEANS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The images are normalized using the means of each color channel for ImageNet, contained in the constant &lt;code&gt;VGG_IMAGENET_MEANS&lt;/code&gt;. Note that the order of the colors are BGR (Blue Green Red) and not RGB, because the network we will use is trained on BGR images and OpenCV uses BGR by default.&lt;/p&gt;
&lt;p&gt;We will also need a function for saving images. Remember to add the color means that were subtracted in preprocessing.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;save_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;color_means&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;Save image to path after postprocessing.&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;color_means&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;clip&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;255&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;astype&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;uint8&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;cv2&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;imwrite&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;initialize-the-generated-image&#34;&gt;Initialize the generated image&lt;/h3&gt;
&lt;p&gt;The generated image is initialized as a combination of the content image and random noise. We could initialize to pure noise, but this would result in longer training time before the content of the generated image resembles the content of the content image.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_noisy_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;noise_ratio&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;Add noise to img and return it.&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;noise&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;np&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;random&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;uniform&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;20&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;20&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;astype&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;float32&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;noisy_img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;noise_ratio&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;noise&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;noise_ratio&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;img&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;noisy_img&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;generated_img_init&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;create_noisy_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;content_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NOISE_RATIO&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;cost-function&#34;&gt;Cost function&lt;/h3&gt;
&lt;p&gt;In order to satisfy the objective of training a generated image to have the content of the content image and the style of the style image, we need to define a cost function that can be optimized. We will define two separate parts of this function, the content cost and the style cost, which we will combine to a total cost.&lt;/p&gt;
&lt;h4 id=&#34;content-cost&#34;&gt;Content cost&lt;/h4&gt;
&lt;p&gt;How can we make sure that the content in the generated image matches the content of the content image? In a CNN, the activations of shallow layers detect low level features such as edges. The activations of deep layers detect higher level features, such as objects like cats or cars. We will use deep layer activations to ensure that the high level features of the generated image are similar to the high level features of the content image.&lt;/p&gt;
&lt;p&gt;Let $a^{(c)}$ be the activations of some deep layer when the content image is used as input to the network, and $a^{(g)}$ be the activations when the generated image is used. Let $n_h, n_w, n_c$ be the height, width, and number of channels of the chosen layer. The content cost is defined as:&lt;/p&gt;
&lt;p&gt;$$
J_{content}(a^{(c)},a^{(g)}) = \frac{1}{4 \times n_h \times n_w \times n_c} \sum_{\text{all entries}}(a^{(c)} - a^{(g)})^2
$$&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;content_cost&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_c&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a_g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;Return a tensor representing the content cost.&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_w&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_c&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a_c&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;/&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_h&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_w&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_c&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;square&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subtract&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_c&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a_g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h4 id=&#34;style-cost&#34;&gt;Style cost&lt;/h4&gt;
&lt;p&gt;The next question is how can we make sure that the style of the generated image matches the style of the style image? The style cost function is a bit more complicated than the content cost function. To start of we need to know how to compute the gram matrix for a set of vectors $(v_1,\dots,v_n)$. The entries $G_{ij}$ of a gram matrix are computed by:&lt;/p&gt;
&lt;p&gt;$$
G_{ij} = v_i \cdot v_j = v_i^T v_j
$$&lt;/p&gt;
&lt;p&gt;Intuitively $G_{ij}$ is one way to measure how similar $v_i$ is to $v_j$, since if they are similar their dot product should be large. We will compute gram matrices by unrolling activations for some layer into a matrix with one row for each channel of that layer, and then taking the matrix product between the matrix and its transpose. The resulting gram matrix $G$ is also called a style matrix and has dimensions $(n_c,n_c)$. As an example of what the style matrix captures: if the channel (also called filter) $i$ is detecting vertical textures, and $j$ is detecting horizontal textures, then $G_{ij}$ measure how much vertical and horizontal textures occur together in the image.&lt;/p&gt;
&lt;p&gt;Using the gram matrices $G^{(s)}$ and $G^{(g)}$ for the style and generated images respectively, the style cost for a given layer $l$ is defined as:&lt;/p&gt;
&lt;p&gt;$$
J_{style}^{[l]}(G^{(s)},G^{(g)}) = \frac{1}{4 \times n_c^2 \times (n_h \times n_w)^2} \sum_{i=1}^{n_c} \sum_{j=1}^{n_c} {(G_{ij}^{(s)} - G_{ij}^{(g)})^2}
$$&lt;/p&gt;
&lt;p&gt;Unlike the content cost, the style cost normally uses several layers with corresponding coefficients $\lambda^{[l]}$:&lt;/p&gt;
&lt;p&gt;$$
J_{style} = \sum_l \lambda^{[l]} J_{style}^{[l]}
$$&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;style_cost&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_s_layers&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a_g_layers&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;style_layer_coefficients&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;Return a tensor representing the style cost.&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;style_cost&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_s_layers&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Compute gram matrix for the activations of the style image&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;a_s&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a_s_layers&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;_&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_w&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_c&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a_s&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;shape&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;a_s_unrolled&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reshape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;transpose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_s&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n_c&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_h&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;*&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n_w&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;a_s_gram&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;matmul&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_s_unrolled&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;transpose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_s_unrolled&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Compute gram matrix for the activations of the generated image&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;a_g&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a_g_layers&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;a_g_unrolled&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reshape&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;transpose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n_c&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_h&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;*&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n_w&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;a_g_gram&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;matmul&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_g_unrolled&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;transpose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_g_unrolled&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Compute style cost for the current layer&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;style_cost_layer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;/&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_c&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;**&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;n_w&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;**&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;reduce_sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;square&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;subtract&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_s_gram&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;a_g_gram&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;style_cost&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;style_cost_layer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;style_layer_coefficients&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;style_cost&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h4 id=&#34;total-cost&#34;&gt;Total cost&lt;/h4&gt;
&lt;p&gt;Let $\alpha$ and $\beta$ be constants. The total cost is defined as a linear combination of the content cost and style cost:&lt;/p&gt;
&lt;p&gt;$$
J = \alpha J_{content} + \beta J_{style}
$$&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;total_cost&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;content_cost&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;style_cost&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;beta&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;Return a tensor representing the total cost.&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;alpha&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;content_cost&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;beta&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;style_cost&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;create-graph&#34;&gt;Create graph&lt;/h3&gt;
&lt;p&gt;Now it is time to create the computation graph and store the tensors that corresponds to the layer activations we are interested in. Let us start by creating a TensorFlow variable which we will use as an input to the network.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;input_var&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;content_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dtype&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;float32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;expected_shape&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;kc&#34;&gt;None&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;NUM_COLOR_CHANNELS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;input_var&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The variable is initialized to the content image, and will later be assigned the style and generated images. Now let us get the pretrained VGG-19 Keras model, extract the layers, and store the output tensors we will need. Note that the MaxPooling layers are swapped for AvgPooling layers. This gives NST better performance.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;create_output_tensors&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;content_layer_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;style_layer_indices&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;s2&#34;&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    Create output tensors, using a pretrained Keras VGG19-model.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    Return tensors for content and style layers.
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;s2&#34;&gt;    &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;vgg_model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;VGG19&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;weights&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;imagenet&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;include_top&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;l&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;l&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;vgg_model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;](&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_variable&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;x_content_tensor&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;x_style_tensors&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;style_layer_indices&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;x_style_tensors&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Use layers from vgg model, but swap max pooling layers for average pooling&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;type&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;MaxPooling2D&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;avg_pool&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;ksize&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;strides&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;padding&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;SAME&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;else&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;layers&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;](&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Store appropriate layer outputs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;content_layer_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;x_content_tensor&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;style_layer_indices&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;x_style_tensors&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x_content_tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x_style_tensors&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;x_content&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x_styles&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;create_output_tensors&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_var&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;CONTENT_LAYER_INDEX&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;STYLE_LAYER_INDICES&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;p&gt;The constant &lt;code&gt;CONTENT_LAYER_INDEX=13&lt;/code&gt; makes us use the activations of layer &lt;code&gt;block4_conv2&lt;/code&gt; for computing content cost. The constant &lt;code&gt;STYLE_LAYER_INDICES = [1, 4, 7, 12, 17]&lt;/code&gt; makes us use the activations of the layers &lt;code&gt;block1_conv1&lt;/code&gt;, &lt;code&gt;block2_conv1&lt;/code&gt;, &lt;code&gt;block3_conv1&lt;/code&gt;, &lt;code&gt;block4_conv1&lt;/code&gt;, and &lt;code&gt;block5_conv1&lt;/code&gt; for computing style cost. Feel free to use other layers, or change the &lt;code&gt;STYLE_LAYER_COEFFICIENTS&lt;/code&gt;. You can use &lt;code&gt;vgg_model.summary()&lt;/code&gt; to see the available layers.&lt;/p&gt;
&lt;h3 id=&#34;create-training-operation&#34;&gt;Create training operation&lt;/h3&gt;
&lt;p&gt;Now it is time to create the training operation. Note that the Keras session is used instead of &lt;code&gt;tf.Session()&lt;/code&gt;, this is because we have used a Keras model and the weights for that model is contained in the Keras session. It is important to not use &lt;code&gt;tf.global_variables_initializer()&lt;/code&gt; in this case, since the pretrained weights would be randomized. Instead we use &lt;code&gt;tf.variables_initializer([input_var])&lt;/code&gt;. Also note that the &lt;code&gt;AdamOptimizer&lt;/code&gt; has variables that we initialize using &lt;code&gt;tf.variables_initializer(optimizer.variables())&lt;/code&gt;.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;AdamOptimizer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;LEARNING_RATE&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Use the Keras session instead of creating a new one&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;K&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;get_session&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variables_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_var&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Extract the layer activations for content and style images&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;a_content&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x_content&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;K&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_phase&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_var&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;assign&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;style_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;a_styles&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;x_styles&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feed_dict&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;K&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;learning_phase&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Define the cost function&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;J_content&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;content_cost&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_content&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x_content&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;J_style&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;style_cost&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;a_styles&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;x_styles&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;STYLE_LAYER_COEFFICIENTS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;J_total&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;total_cost&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J_content&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;J_style&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;ALPHA&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;BETA&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Log the graph. To display use &amp;#34;tensorboard --logdir=log&amp;#34;.&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;LOG_GRAPH&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;writer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;summary&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;FileWriter&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;log&amp;#34;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;graph&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;writer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;close&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Assign the generated random initial image as input&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_var&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;assign&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;generated_img_init&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Create the training operation&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;minimize&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J_total&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;var_list&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_var&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tf&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variables_initializer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;variables&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()))&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h3 id=&#34;train-the-generated-image&#34;&gt;Train the generated image&lt;/h3&gt;
&lt;p&gt;It is finally time to train the generated image! The generated image is saved every 20th iteration so you can see the progress. The following code should be inside the &lt;code&gt;with K.get_session() as sess:&lt;/code&gt; block created above.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;NUM_ITERATIONS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_op&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;if&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;%&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;20&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;s2&#34;&gt;&amp;#34;Iteration: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;str&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;s2&#34;&gt;&amp;#34;, Content cost: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{:.2e}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;format&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J_content&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;s2&#34;&gt;&amp;#34;, Style cost: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{:.2e}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;format&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J_style&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;s2&#34;&gt;&amp;#34;, Total cost: &amp;#34;&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{:.2e}&lt;/span&gt;&lt;span class=&#34;s2&#34;&gt;&amp;#34;&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;format&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;J_total&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Save the generated image&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;generated_img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_var&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;save_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;generated_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;GENERATED_IMG_PATH&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;VGG_IMAGENET_MEANS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Save the generated image&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;generated_img&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;sess&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;run&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;input_var&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;save_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;generated_img&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;GENERATED_IMG_PATH&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;VGG_IMAGENET_MEANS&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;
&lt;h2 id=&#34;conclusion&#34;&gt;Conclusion&lt;/h2&gt;
&lt;p&gt;I am happy about how the project turned out. Neural Style Transfer is a really cool technique, and it is always nice with visual results that can be appreciated by non-ML people!&lt;/p&gt;
&lt;p&gt;The way I extracted the layers from the pretrained Keras model does seem slightly &amp;ldquo;hacky&amp;rdquo;, but overall I feel like it is a decent alternative that I probably will use again.&lt;/p&gt;
&lt;p&gt;The source code can be found at &lt;a href=&#34;https://github.com/CarlFredriksson/neural_style_transfer&#34;&gt;https://github.com/CarlFredriksson/neural_style_transfer&lt;/a&gt;.&lt;/p&gt;
&lt;p&gt;Thank you for reading, and feel free to send me any questions.&lt;/p&gt;
</description>
    </item>
    
    
  </channel>
</rss>
