Solving XOR with a Neural Network in Python

In the previous few posts, I built a simple neural network that solves the XOR problem using Octave, a handy numerical computing package.

I find Octave quite useful as it is built to do linear algebra and matrix operations, both of which are crucial to standard feed-forward multi-layer neural networks. However, it isn’t particularly fast, and you would not use it to build deep-learning models on large datasets.

Coding in Python

Python has its own numerical library, NumPy. It is widely used for building neural networks, so I wanted to compare a NumPy implementation of the network with the Octave one.

The last post showed an Octave function to solve the XOR problem. Recall that we want the network to output zero when x1 and x2 are the same (the yellow circles) and one when x1 and x2 are different (the blue circles):

[Figure: XOR graph]
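As a rough sketch, the four points in the graph can be written as a NumPy array with one row per example. The [x1, x2, target] row layout is inferred from how the xor_nn function in this post indexes each row (x[0:2] for the inputs, x[2] for the target); the variable name targets_ok is only for illustration.

```python
import numpy as np

# Each row is [x1, x2, target], matching how xor_nn reads x[0:2] and x[2].
XOR = np.array([[0, 0, 0],
                [0, 1, 1],
                [1, 0, 1],
                [1, 1, 0]])

# Sanity check: the target column is exactly x1 XOR x2.
expected = np.logical_xor(XOR[:, 0], XOR[:, 1]).astype(int)
targets_ok = np.array_equal(XOR[:, 2], expected)
print(targets_ok)  # True
```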

Here is the topology of the network we want to train:

[Figure: network topology]
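To make the shapes concrete, here is a small sketch of one forward pass through a network with this topology: two inputs plus a bias, two hidden units plus a bias, and one output. The weights below are hand-picked for illustration, not trained and not taken from the earlier posts; they make the hidden units behave roughly like OR and NAND and the output like AND, which together give XOR. The forward helper is a hypothetical name.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Hand-picked (not trained) weights; column 0 of each matrix is the bias weight.
THETA1 = np.array([[-10.0, 20.0, 20.0],     # hidden unit 1, roughly OR
                   [30.0, -20.0, -20.0]])   # hidden unit 2, roughly NAND
THETA2 = np.array([[-30.0, 20.0, 20.0]])    # output, roughly AND of the hidden units

def forward(x1, x2):
    a1 = np.array([[1.0], [x1], [x2]])                   # 3x1: bias + inputs
    a2 = np.vstack(([1.0], sigmoid(np.dot(THETA1, a1)))) # 3x1: bias + hidden units
    h = sigmoid(np.dot(THETA2, a2))                      # 1x1: network output
    return h[0, 0]

inputs = [(0, 0), (0, 1), (1, 0), (1, 1)]
predictions = [int(forward(x1, x2) > 0.5) for x1, x2 in inputs]
print(predictions)  # [0, 1, 1, 0]
```

This only shows that a 2-2-1 network can represent XOR; training has to find comparable weights on its own.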

Lastly, here is a function in Python that is equivalent to the Octave xor_nn function. The code also includes a sigmoid function:

import numpy as np
import math

def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

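def xor_nn(XOR, THETA1, THETA2, init_w=0, learn=0, alpha=0.01):
        # Each row of XOR is [x1, x2, y]. THETA1 is 2x3 and THETA2 is 1x3;
        # column 0 of each holds the bias weights.
        if init_w == 1:
                # Random initial weights, uniform in [-1, 1)
                THETA1 = 2*np.random.random([2,3]) - 1
                THETA2 = 2*np.random.random([1,3]) - 1

        # Gradient accumulators, example counter and cost
        T1_DELTA = np.zeros(THETA1.shape)
        T2_DELTA = np.zeros(THETA2.shape)
        m = 0
        J = 0.0

        for x in XOR:
                # Forward propagation; a bias unit of 1 is prepended to each layer
                A1 = np.vstack(([1], np.transpose(x[0:2][np.newaxis])))
                Z2 = np.dot(THETA1, A1)
                A2 = np.vstack(([1], sigmoid(Z2)))
                Z3 = np.dot(THETA2, A2)
                h = sigmoid(Z3)

                # Accumulate the log-likelihood; h is 1x1, so h[0, 0] is its scalar value
                J = J + (x[2] * math.log(h[0, 0])) + ((1 - x[2]) * math.log(1 - h[0, 0]))
                m = m + 1
                if learn == 1:
                        # Backpropagation; [1:] drops the bias unit's error term
                        delta3 = h - x[2]
                        delta2 = (np.dot(np.transpose(THETA2), delta3) * (A2 * (1 - A2)))[1:]
                        T2_DELTA = T2_DELTA + np.dot(delta3, np.transpose(A2))
                        T1_DELTA = T1_DELTA + np.dot(delta2, np.transpose(A1))
                else:
                        print(h)

        # Mean cross-entropy cost over the m examples
        J = J / -m

        if learn == 1:
                # One batch gradient-descent step using the averaged gradients
                THETA1 = THETA1 - (alpha * (T1_DELTA / m))
                THETA2 = THETA2 - (alpha * (T2_DELTA / m))
        else:
                print(J)

        return (THETA1, THETA2)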

(This code is available on GitHub if you want to download it: Python NN on GitHub)

If you want more detail on how this function works, have a look back at Part 1, Part 2 and Part 3 of the series on the Octave version.
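For readers who prefer to see the whole training loop in one place, here is a minimal vectorized sketch of the same idea: forward pass, cross-entropy cost, backpropagation and one batch gradient-descent step per epoch. It is not the function above; train_epoch, the seed, the learning rate of 0.5 and the 5000 epochs are all choices made for this illustration. Depending on the initial weights, a 2-2-1 network may need more epochs or can settle on a poor solution, so the sketch only reports how the cost changes.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train_epoch(X, y, theta1, theta2, alpha):
    """One batch gradient-descent step; returns updated weights and the pre-update cost."""
    m = X.shape[0]
    a1 = np.hstack([np.ones((m, 1)), X])                        # m x 3, bias column first
    a2 = np.hstack([np.ones((m, 1)), sigmoid(a1 @ theta1.T)])   # m x 3
    h = sigmoid(a2 @ theta2.T)                                  # m x 1
    cost = -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))    # mean cross-entropy
    delta3 = h - y                                              # m x 1 output error
    delta2 = ((delta3 @ theta2) * a2 * (1 - a2))[:, 1:]         # m x 2, bias column dropped
    grad2 = delta3.T @ a2 / m                                   # 1 x 3
    grad1 = delta2.T @ a1 / m                                   # 2 x 3
    return theta1 - alpha * grad1, theta2 - alpha * grad2, cost

# XOR inputs (4x2) and targets (4x1)
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
y = np.array([[0.0], [1.0], [1.0], [0.0]])

rng = np.random.default_rng(0)        # arbitrary fixed seed for repeatability
theta1 = 2 * rng.random((2, 3)) - 1   # uniform in [-1, 1)
theta2 = 2 * rng.random((1, 3)) - 1

costs = []
for _ in range(5000):
    theta1, theta2, cost = train_epoch(X, y, theta1, theta2, alpha=0.5)
    costs.append(cost)

print(f"cost: {costs[0]:.4f} -> {costs[-1]:.4f}")
```

Processing all four examples as one matrix removes the Python-level loop over rows, at the price of being a little harder to compare line by line with the Octave version.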

Comparing Python and Octave

To be sure that they both operate identically, I first generated some random numbers. These numbers were used to initialize the parameters (THETA1 and THETA2) in both functions. If you run several epochs (I ran 1000), then you will see that the values of THETA1 and THETA2 remain identical in Octave and Python. This makes sense as the only non-deterministic part of the algorithm is the initialization of the network’s parameters (i.e. the weights).

So we can be sure that both functions are executing the same algorithm.
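The post does not say how the shared starting weights were passed between the two environments. One possible approach, sketched here as an assumption, is to generate them once in Python and write them to CSV files that both sides can read. The file names, seed and temporary directory are hypothetical.

```python
import os
import tempfile
import numpy as np

rng = np.random.default_rng(42)       # arbitrary seed, for illustration only
THETA1 = 2 * rng.random((2, 3)) - 1   # same shapes as in xor_nn
THETA2 = 2 * rng.random((1, 3)) - 1

out_dir = tempfile.mkdtemp()
theta1_path = os.path.join(out_dir, "theta1.csv")   # hypothetical file names
theta2_path = os.path.join(out_dir, "theta2.csv")

# The default '%.18e' format writes enough digits to preserve float64 values.
np.savetxt(theta1_path, THETA1, delimiter=",")
np.savetxt(theta2_path, THETA2, delimiter=",")

# ndmin=2 keeps the 1x3 matrix two-dimensional when it is read back.
THETA1_loaded = np.loadtxt(theta1_path, delimiter=",", ndmin=2)
THETA2_loaded = np.loadtxt(theta2_path, delimiter=",", ndmin=2)

print(np.allclose(THETA1, THETA1_loaded), np.allclose(THETA2, THETA2_loaded))
```

On the Octave side a file like this could be read with csvread or dlmread. After training, the resulting weights from both implementations can be exported the same way and compared with np.allclose.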

The next step is to test how fast the networks run a large number of epochs. On my (ageing) MacBook, the Octave code runs 1000 epochs in about 9.5 seconds on average, while the Python code runs the same number in just 5.4 seconds. That is roughly 43% less time for what is practically the same code.
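If you want to reproduce this kind of measurement on the Python side, a simple timing harness might look like the sketch below. It is not the script used for the numbers above; time_epochs and dummy_step are hypothetical names, and the stand-in workload is just a single small forward step, so you would pass in a function that runs one real training epoch instead.

```python
import time
import numpy as np

def time_epochs(step, n_epochs=1000, repeats=5):
    """Return the mean wall-clock seconds needed to call step() n_epochs times."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        for _ in range(n_epochs):
            step()
        durations.append(time.perf_counter() - start)
    return sum(durations) / len(durations)

# Stand-in workload: one sigmoid layer on fixed data (replace with a real epoch).
theta = np.ones((2, 3))
a1 = np.ones((3, 1))

def dummy_step():
    return 1.0 / (1.0 + np.exp(-np.dot(theta, a1)))

mean_seconds = time_epochs(dummy_step)
print(f"{mean_seconds:.4f} s per 1000 epochs (mean of 5 runs)")
```

Averaging over several repeats smooths out noise from other processes, which matters when the two timings are only a few seconds apart.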

So if you are familiar with Python and want to start developing your own neural networks, then NumPy will give you the tools you need.

 

 
