#!/usr/bin/python


import sys, os
import getopt
import random
from time import clock, time, sleep
from psycopg import connect, ProgrammingError


# parameters (defaults; overwritten by parse_args() from the command line)
DSN_DB = ""         # database name (-d); required
DSN_HOST = None     # database server host (-h); None = libpq default
CONN_COUNT = 1      # number of processes/connections (-c)
ITER_COUNT = 1      # test-loop iterations per connection (-i)
ITER_SLEEP = 0.0    # seconds to sleep after each query (-s)
ROW_COUNT = 1       # rows requested from test_setof_tuple() per query (-r)
DSN = ""            # libpq connection string, assembled by main()

# global data: per-process latency extremes, updated by test_func()
MAX = 0
# Start MIN at +infinity so the first measured duration always replaces
# it (a finite sentinel such as 99 would be wrong for slower queries).
MIN = float("inf")


def usage (appname):
    """Print the command-line synopsis and option help, then exit.

    appname -- the program path (argv[0]); only its basename is shown.

    Always terminates the process via sys.exit(2).
    """
    # A single parenthesised expression prints identically under the
    # Python 2 print statement and the Python 3 print function.
    print("Usage:  " + os.path.basename(appname) + " " +
          """[-c count] [-i count] [-r count] [-s seconds] [-h host] -d dbname
DB performance test script

  -c count      number of connections to test database
  -i count      number of iterations during test-loop per connection
  -r count      number of rows to return in each SETOF
  -h host       database server host name
  -d dbname     database name
  -s seconds    time to sleep after each query
""")
    sys.exit(2)


def parse_args (argv):
    """Parse command-line options into the module-level parameters.

    argv -- the full argument vector; argv[0] is the program name.

    Sets DSN_DB, DSN_HOST, CONN_COUNT, ITER_COUNT, ROW_COUNT and
    ITER_SLEEP.  Calls usage() (which exits with status 2) on an unknown
    option, a malformed numeric value, or a missing/empty -d dbname.
    """
    global DSN_DB, DSN_HOST, CONN_COUNT, ITER_COUNT, ROW_COUNT, ITER_SLEEP
    try:
        # every option handled below must appear here, "s:" included,
        # otherwise getopt rejects it before the loop can see it
        opts, _ = getopt.getopt(argv[1:], "c:i:r:h:d:s:")
    except getopt.GetoptError as e:
        print(str(e) + "\n")
        usage(argv[0])

    try:
        for o, a in opts:
            if o == "-c":
                # number of connections to test database
                CONN_COUNT = int(a)
            elif o == "-i":
                # number of iterations during test-loop per connection
                ITER_COUNT = int(a)
            elif o == "-r":
                # number of rows to return in each SETOF
                ROW_COUNT = int(a)
            elif o == "-h":
                # set database host name
                DSN_HOST = a
            elif o == "-d":
                # set database name
                DSN_DB = a
            elif o == "-s":
                # seconds to sleep after each query; sleep() needs a number
                ITER_SLEEP = float(a)
    except ValueError as e:
        print("Invalid numeric option value: %s\n" % e)
        usage(argv[0])

    # DSN_DB defaults to "" (not None), so test for emptiness
    if not DSN_DB:
        print("No database name specified (-d dbname)\n")
        usage(argv[0])



def test_func (cursor, count):
    """Run one test_setof_tuple() query and record its latency.

    cursor -- an open DB-API cursor.
    count  -- value passed as the first argument of test_setof_tuple();
              presumably the number of rows it returns (see -r).

    Executes the query, fetches all returned rows, and folds the elapsed
    wall-clock time into the module-level MAX / MIN.  If the query raises
    ProgrammingError, the error is reported on stderr and no timing is
    recorded.
    """
    global MAX, MIN
    # time() is wall-clock; clock() is process CPU time on Unix and would
    # not count the time spent waiting for the database server.
    begin = time()

    try:
        # send data and receive; %s is the DB-API "pyformat" placeholder,
        # the driver converts and quotes `count` itself
        cursor.execute("select * from test_setof_tuple(%s, "
                       "'(data,12345678,data,1,2,3,4,5,6)')",
                       [count])
    except ProgrammingError as e:
        sys.stderr.write("query failed: %s\n" % e)
        return

    # the rows are discarded, but receiving them is part of the measured work
    cursor.fetchall()

    # measure time
    duration = time() - begin
    MAX = max(MAX, duration)
    MIN = min(MIN, duration)


def _build_dsn(dbname, host=None):
    """Return a libpq connection string for database *dbname* on *host*."""
    dsn = "dbname=" + dbname
    if host is not None:
        # libpq's keyword is "host"; "dbhost" is rejected as an invalid
        # connection option.
        dsn = "host=" + host + " " + dsn
    return dsn


def _fork_workers(count):
    """Fork so that *count* processes in total run the test.

    Returns (worker_id, is_main).  The original process gets (0, True);
    each child gets a distinct id in 1..count-1 and is_main False, and
    returns at once so children never fork grandchildren.
    """
    for worker_id in range(1, count):
        if os.fork() == 0:
            return worker_id, False
    return 0, True


def main (argv):
    """Entry point: parse options, fork workers, run and time the test loop.

    Every process (the parent and each forked child) opens its own
    connection, runs ITER_COUNT iterations of test_func() and prints its
    own statistics.  The parent then waits for all children and reports
    the wall-clock time of the whole run.
    """
    parse_args(argv)

    # DSN remains a module-level global, as before
    global DSN
    DSN = _build_dsn(DSN_DB, DSN_HOST)

    # start the overall clock before forking so the children's work lies
    # inside the measured interval
    total = time()
    worker_id, is_main = _fork_workers(CONN_COUNT)

    # NOTE(review): nothing visible in this file uses `random`; the
    # per-process seed is kept in case test queries rely on it.
    random.seed(os.getpid())

    completed = 0   # iterations that finished their query
    busy = 0.0      # wall-clock seconds spent inside test_func()
    db = connect(DSN)
    try:
        for n in range(ITER_COUNT):
            cursor = db.cursor()
            # time only the query itself, not the commit, output or sleep
            start = time()
            test_func(cursor, ROW_COUNT)
            busy += time() - start
            db.commit()
            cursor.close()
            completed += 1
            sys.stdout.write("[%d:%d] " % (worker_id, n))
            sys.stdout.flush()
            sleep(ITER_SLEEP)
    except KeyboardInterrupt:
        pass
    finally:
        # release the connection even if the loop raised something else
        db.close()

    # print statistics; avg is over the iterations that actually ran, and
    # MIN is still its initial sentinel if no query was recorded
    avg = busy / completed if completed else 0.0
    low = MIN if MIN <= MAX else 0.0
    print("\ntotal=%.4f, min=%.4f, avg=%.4f, max=%.4f" %
          (busy, low, avg, MAX))

    if is_main:
        # wait for all children to finish before reporting overall time
        for _ in range(CONN_COUNT - 1):
            os.waitpid(-1, 0)
        print("total test time took %s second(s)" % (time() - total))


# Script entry point: run the benchmark with this process's command line.
# Forked children (see main) also return through here and then reach the
# end of the script, which ends those processes.
if __name__ == "__main__":
    main(sys.argv)
