#! /usr/bin/env python

import unittest
from httplib import OK, BAD_REQUEST

import requests
import json
from subprocess import Popen
import time


class TestOnosDistributedManager(unittest.TestCase):
    """End-to-end test of the onos-distributed-manager HTTP interface.

    Every test launches the ``onos-distributed-manager`` executable as a
    child process (it must be on ``PATH``) and talks to it over HTTP at
    ``base_url``.  The child is asked to exit, and is reaped, after each
    test so that no process outlives the test run.
    """

    base_url = 'http://localhost:5000/'
    cluster_json_url = base_url + 'cluster.json'

    # Upper bounds, in seconds, on how long to wait for the child process
    # to start answering requests and to exit after being asked to.
    startup_timeout = 10.0
    shutdown_timeout = 5.0
    # Delay, in seconds, between two consecutive polls of the child.
    poll_interval = 0.1

    def setUp(self):
        """Start the server process and wait until it answers requests."""
        command = ["onos-distributed-manager"]
        self.server_process = Popen(command)
        try:
            self._wait_for_server()
        except Exception:
            # unittest does not call tearDown() when setUp() raises, so
            # the child has to be cleaned up here before propagating.
            self._reap_server()
            raise

    def tearDown(self):
        """Ask the server to exit, then make sure the process is gone."""
        try:
            requests.put(self.base_url + 'exit')
        finally:
            # Reap the child even when the exit request itself failed.
            self._reap_server()

    def _wait_for_server(self):
        """Poll ``cluster_json_url`` until the server accepts a connection.

        Fails the test if the child process exits while starting up, and
        re-raises the last connection error once ``startup_timeout``
        seconds have elapsed without a successful request.
        """
        deadline = time.time() + self.startup_timeout
        while True:
            if self.server_process.poll() is not None:
                self.fail("onos-distributed-manager exited during startup "
                          "with code %s" % self.server_process.returncode)
            try:
                requests.get(self.cluster_json_url)
                return
            except requests.exceptions.ConnectionError:
                if time.time() >= deadline:
                    raise
            time.sleep(self.poll_interval)

    def _reap_server(self):
        """Wait for the child to exit; kill it after ``shutdown_timeout``."""
        deadline = time.time() + self.shutdown_timeout
        while self.server_process.poll() is None:
            if time.time() >= deadline:
                self.server_process.kill()
                self.server_process.wait()
                return
            time.sleep(self.poll_interval)

    def test_operations(self):
        """Exercise add, duplicate add, modify and delete of cluster nodes.

        After every step the cluster description served at
        ``cluster_json_url`` is fetched and its ``nodes`` and
        ``partitions`` lists are checked.
        """
        node1_json = '{"id":"node1","ip":"10.128.11.161","port":9876}'
        node2_json = '{"id":"node2","ip":"10.128.11.162","port":9876}'
        node3_json = '{"id":"node3","ip":"10.128.11.163","port":9876}'
        node4_json = '{"id":"node4","ip":"10.128.11.164","port":9876}'
        node5_json = '{"id":"node5","ip":"10.128.11.165","port":9876}'
        new_node5_json = '{"id":"node5","ip":"10.128.11.265","port":1234}'

        node1_dict = json.loads(node1_json)
        new_node5_dict = json.loads(new_node5_json)

        post_headers = {'Content-Type': 'application/json'}

        def fetch_cluster():
            """GET the cluster description and return it as parsed JSON."""
            response = requests.get(self.cluster_json_url)
            self.assertEqual(response.status_code, OK)
            return response.json()

        def check_list_lengths(expected_length):
            """Assert that nodes and partitions both have the given size."""
            cluster = fetch_cluster()
            self.assertEqual(len(cluster["nodes"]), expected_length)
            self.assertEqual(len(cluster["partitions"]), expected_length)

        def post_node(node_json):
            """POST a node description and expect it to be accepted."""
            response = requests.post(self.base_url,
                                     data=node_json,
                                     headers=post_headers)
            self.assertEqual(response.status_code, OK)

        def delete_node(json_string):
            """DELETE a node description and expect it to be accepted."""
            response = requests.delete(self.base_url,
                                       data=json_string,
                                       headers=post_headers)
            self.assertEqual(response.status_code, OK)

        # Test initial configuration - node list should be empty
        check_list_lengths(0)

        # Test post operation to add a node
        post_node(node1_json)

        cluster = fetch_cluster()
        self.assertEqual(len(cluster["nodes"]), 1)
        self.assertEqual(len(cluster["partitions"]), 1)
        self.assertIn(node1_dict, cluster["nodes"])
        self.assertIn(node1_dict["id"], cluster["partitions"][0]["members"])

        # Re-posting the same node should cause an error
        response = requests.post(self.base_url,
                                 data=node1_json,
                                 headers=post_headers)
        self.assertEqual(response.status_code, BAD_REQUEST)
        check_list_lengths(1)

        # Add 4 more nodes and check the partitions generated
        post_node(node2_json)
        post_node(node3_json)
        post_node(node4_json)
        post_node(node5_json)

        check_list_lengths(5)
        cluster = fetch_cluster()
        self.assertEqual(["node4", "node5", "node1"],
                         cluster["partitions"][3]["members"])

        # modify a node
        response = requests.put(self.base_url,
                                data=new_node5_json,
                                headers=post_headers)
        self.assertEqual(response.status_code, OK)
        cluster = fetch_cluster()
        self.assertIn(new_node5_dict, cluster["nodes"])

        # delete all the nodes
        delete_node(new_node5_json)
        delete_node(node4_json)
        delete_node(node3_json)
        delete_node(node2_json)
        delete_node(node1_json)

        # make sure that no nodes remain
        check_list_lengths(0)

# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
