view shelltools/query-pr/query.py @ 56:42d7888272a0 default tip

Implement fetch_classifications().
author David A. Holland
date Sun, 10 Apr 2022 19:37:18 -0400
parents 40f64a96481f
children
line wrap: on
line source

#!@PYTHON@

import sys
import argparse
import psycopg2

# Version string for this program. "@VERSION@" looks like a placeholder
# that gets substituted when the script is built or installed
# (NOTE(review): confirm against the build system).
program_version = "@VERSION@"

############################################################
# settings

# Default output destination, amount of material, and format.
# NOTE(review): print_prs() compares against upper-case names such as
# "TEXT" and "HEADERS"; confirm how these lower-case defaults are mapped.
outfile = sys.stdout
outmaterial = "headers"
outformat = "text"

############################################################
# simple dump widget

class Dumper:
        """
        Minimal line writer with an indentation level.

        Each write() emits one line to the wrapped file-like object,
        prefixed with the current indentation; indent() and unindent()
        move the level by three columns.
        """

        def __init__(self, f):
                self.indentation = 0
                self.f = f

        def indent(self):
                self.indentation = self.indentation + 3

        def unindent(self):
                self.indentation = self.indentation - 3

        def write(self, msg):
                padding = self.indentation * " "
                self.f.write("".join((padding, msg, "\n")))

############################################################
# database field access

#
# Fields of PRs that we might search are spread across a number of
# tables and require varying joins to get them. And, because of
# classification schemes, the set of fields isn't static and we can't
# just assemble a massive view with one column for each field.
#
# The QueryBuilder class knows how to arrange for all known fields to
# be present.
#

# Names of the known classification schemes, one list per scheme type.
# They start out empty and are loaded below: fetch_classifications()
# rebinds each of them to the names read from the database.
# QueryBuilder.getfield() consults them to recognize scheme names as
# searchable fields.
hierclasses = []
flatclasses = []
textclasses = []
tagclasses = []

class QueryBuilder:
        """
        Accumulates the FROM items and WHERE conditions needed to reach
        the PR fields a search refers to, then renders the SELECT.

        Call getfield() for every field of interest: it registers the
        joins that field requires and returns the SQL expression that
        names it. Extra conditions go in through addwhere(). build()
        renders the final query text.
        """

        # these fields are in the PRs table
        prtable_fields = [
                "id", "synopsis", "confidential", "state", "locked",
                "timeout_date", "timeout_state",
                "arrival_schemaversion", "arrival_date", "modified_date",
                "closed_date",
                "release", "environment"
        ]

        # these fields are aliases for others
        alias_fields = {
                "number" : "id",
                "date" : "arrival_date",
        }

        def __init__(self):
                # field name -> SQL expression, for fields already set up
                self.present = {}
                # keys of the joins already emitted (used as a set)
                self.joined = {}
                # FROM items, in the order they were first requested
                self.fromitems = []
                # WHERE conditions; build() joins them with "and"
                self.whereitems = []
                # ORDER BY expression, or None for no ordering
                self.order = None

        def setorder(self, order):
                """Set the ORDER BY expression (None leaves the query unordered)."""
                self.order = order

        # add to present{} and return the value for convenience (internal)
        def makepresent(self, field, name):
                self.present[field] = name
                return name

        # add a join item (once only) (internal)
        def addjoin(self, table, as_ = None):
                if as_ is not None:
                        key = table + "-" + as_
                        val = table + " AS " + as_
                else:
                        key = table
                        val = table
                if key not in self.joined:
                        self.joined[key] = True
                        self.fromitems.append(val)

        # add a where condition
        #
        # Conditions are deliberately not deduplicated: a caller may add
        # the same text twice with different positional arguments.
        def addwhere(self, cond):
                self.whereitems.append(cond)

        # returns a sql expression for the field
        def getfield(self, field):
                """
                Arrange for `field` to be reachable and return the SQL
                expression for it.

                Unknown fields are reported on stderr and terminate the
                program with exit status 1.
                """
                # already-fetched fields
                if field in self.present:
                        return self.present[field]

                # aliases for other fields
                # (class attributes must be reached through self; a bare
                # name is not visible from inside a method)
                if field in self.alias_fields:
                        return self.getfield(self.alias_fields[field])

                # simple fields directly in the PRs table
                if field in self.prtable_fields:
                        self.addjoin("PRs")
                        return self.makepresent(field, "PRs." + field)

                # now it gets more interesting...
                if field == "closed":
                        self.addjoin("PRs")
                        self.addjoin("states")
                        self.addwhere("PRs.state = states.name")
                        return self.makepresent(field, "states.closed")

                # XXX let's pick one set of names and use them everywhere
                # (e.g. change "posttime" in the schema to "message_date"
                # or something)
                if field == "comment_date" or field == "posttime":
                        self.addjoin("PRs")
                        self.addjoin("messages")
                        self.addwhere("PRs.id = messages.pr")
                        return self.makepresent(field, "messages.posttime")

                if field == "comment" or field == "message" or field == "post":
                        self.addjoin("PRs")
                        self.addjoin("messages")
                        self.addwhere("PRs.id = messages.pr")
                        return self.makepresent(field, "messages.body")

                if field == "attachment":
                        self.addjoin("PRs")
                        self.addjoin("messages")
                        self.addjoin("attachments")
                        self.addwhere("PRs.id = messages.pr")
                        self.addwhere("messages.id = attachments.msgid")
                        return self.makepresent(field, "attachments.body")

                if field == "patch":
                        self.addjoin("PRs")
                        self.addjoin("messages")
                        self.addjoin("attachments", "patches")
                        self.addwhere("PRs.id = messages.pr")
                        self.addwhere("messages.id = patches.msgid")
                        self.addwhere("patches.mimetype = " +
                                        "'application/x-patch'")
                        return self.makepresent(field, "patches.body")

                if field == "mimetype":
                        subquery = "((SELECT mtmessages1.pr as pr, " + \
                                "mtmessages1.mimetype as mimetype " + \
                                "FROM messages as mtmessages1) " + \
                                "UNION " + \
                                "(SELECT mtmessages2.pr as pr, " + \
                                "mtattach2.mimetype as mimetype " + \
                                "FROM messages as mtmessages2, " + \
                                "     attachments as mtattach2 " + \
                                "WHERE mtmessages2.id = mtattach2.msgid))"
                        self.addjoin("PRs")
                        self.addjoin(subquery, "mimetypes")
                        self.addwhere("PRs.id = mimetypes.pr")
                        return self.makepresent(field, "mimetypes.mimetype")

                # XXX: need view userstrings
                # select (id, username as name) from users
                # union select (id, realname as name) from users
                # (allow searching emails? ugh)
                if field == "originator" or field == "submitter":
                        self.addjoin("PRs")
                        self.addjoin("userstrings", "originators")
                        self.addwhere("PRs.originator = originators.id")
                        return self.makepresent(field, "originators.name")

                # The two subscription-based fields each get their own
                # alias of the subscriptions table, tied back to the PR.
                # Sharing one unaliased join would demand a single row
                # that is both reporter and responsible, and leaving out
                # the PR condition yields a cross product.
                if field == "reporter" or field == "respondent":
                        self.addjoin("PRs")
                        self.addjoin("subscriptions", "reporter_subs")
                        self.addjoin("userstrings", "reporters")
                        self.addwhere("PRs.id = reporter_subs.pr")
                        self.addwhere("reporter_subs.userid = reporters.id")
                        self.addwhere("reporter_subs.reporter")
                        return self.makepresent(field, "reporters.name")

                if field == "responsible":
                        self.addjoin("PRs")
                        self.addjoin("subscriptions", "responsible_subs")
                        self.addjoin("userstrings", "responsibles")
                        self.addwhere("PRs.id = responsible_subs.pr")
                        self.addwhere("responsible_subs.userid = " +
                                        "responsibles.id")
                        self.addwhere("responsible_subs.responsible")
                        return self.makepresent(field, "responsibles.name")

                # classification schemes: one data table per scheme type,
                # joined under a per-scheme alias
                classtables = (
                        (hierclasses, "hierclass_data"),
                        (flatclasses, "flatclass_data"),
                        (textclasses, "textclass_data"),
                        (tagclasses, "tagclass_data"),
                )
                for (names, table) in classtables:
                        if field in names:
                                col = field + "_data"
                                self.addjoin("PRs")
                                self.addjoin(table, col)
                                self.addwhere("PRs.id = %s.pr" % col)
                                self.addwhere("%s.scheme = '%s'" %
                                                (col, field))
                                return self.makepresent(field,
                                                "%s.value" % col)

                sys.stderr.write("Unknown field %s\n" % field)
                sys.exit(1)
        # end getfield

        # emit sql
        def build(self, sels):
                """
                Render the query selecting the expressions in `sels`.

                The WHERE clause is left out entirely when no conditions
                were registered, since an empty "WHERE" is not valid SQL.
                """
                s = ", ".join(sels)
                f = ", ".join(self.fromitems)
                q = "SELECT %s\nFROM %s\n" % (s, f)
                if self.whereitems:
                        q = q + "WHERE %s\n" % " and ".join(self.whereitems)
                if self.order is not None:
                        q = q + "ORDER BY " + self.order + "\n"
                return q
        # end build

# end class QueryBuilder

# XXX we need to add dynamically:
#    hierclass_names.name to hierclasses[]
#    flatclass_names.name to flatclasses[]
#    textclass_names.name to textclasses[]
#    tagclass_names.name to tagclasses[]

############################################################
# database

# Database connection shared by opendb(), closedb() and querydb().
# None until opendb() is called and again after closedb().
dblink = None

def opendb(paranoid):
        """
        Connect to the "swallowtail" database on localhost and store the
        connection in the module-level dblink.

        A true `paranoid` selects the "swallowtail_public" database
        user; otherwise "swallowtail_reader" is used.
        """
        global dblink

        if paranoid:
                role = "swallowtail_public"
        else:
                role = "swallowtail_reader"
        settings = ["host=" + "localhost",
                    "user=" + role,
                    "dbname=" + "swallowtail"]
        dblink = psycopg2.connect(" ".join(settings))
# end opendb

def closedb():
        """
        Close the shared database connection, if there is one.

        Calling this when no connection is open is a no-op. The
        module-level dblink is reset to None even if close() raises,
        so a broken handle is never left behind.
        """
        global dblink

        if dblink is None:
                return
        try:
                dblink.close()
        finally:
                dblink = None
# end closedb

def querydb(qtext, args):
        """
        Run one query on the shared connection and return all its rows.

        qtext -- the SQL text
        args  -- the parameter sequence handed to cursor.execute()

        Returns whatever cursor.fetchall() returns. The cursor is closed
        on every path, including when execute() or fetchall() raises.
        """
        # NOTE(review): this trace goes to stdout, the same stream the
        # print_* functions write their results to.
        print("Executing this query:")
        print(qtext)
        print("Args are:")
        print(args)

        cursor = dblink.cursor()
        try:
                cursor.execute(qtext, args)
                return cursor.fetchall()
        finally:
                cursor.close()
# end querydb

############################################################
# classification schemes

#
# Load the available classification schemes from the database.
# We only need their names up front.
#
def fetch_classifications():
        """
        Load the classification scheme names from the database into the
        module-level hierclasses, flatclasses, textclasses and
        tagclasses lists.

        The four lists are rebound only after all four queries have
        succeeded, so a failure part-way cannot leave some of them
        holding raw result rows.
        """
        global hierclasses, flatclasses, textclasses, tagclasses

        # The dashed lines are SQL comments; they only frame the query
        # in the trace that querydb() prints.
        template = '''
------------------------------
        SELECT name FROM %s ORDER BY ordering;
------------------------------
'''
        tables = ["hierclass_names", "flatclass_names",
                  "textclass_names", "tagclass_names"]

        loaded = []
        for table in tables:
                rows = querydb(template % table, [])
                # The results come back as one-element tuples; unwrap them
                loaded.append([name for (name,) in rows])

        (hierclasses, flatclasses, textclasses, tagclasses) = loaded

############################################################
# query class for searches
# XXX: obsolete, remove

class Query:
        """
        Flat SELECT description: selections, tables, constraints and the
        positional arguments the constraints refer to.
        """

        def __init__(self):
                self.selections = []
                self.tables = []
                self.constraints = []
                self.args = []
        prtables = ["PRs"]
        prconstraints = []

        def select(self, s):
                """Add one expression to the SELECT list."""
                self.selections.append(s)

        def addtable(self, t):
                """Add one FROM item; each item may be added only once."""
                assert(t not in self.tables)
                self.tables.append(t)

        def constrain(self, expr):
                """Add one condition; conditions are ANDed by textify()."""
                self.constraints.append(expr)

        def internval(self, val):
                """
                Store `val` as the next positional argument and return
                the placeholder to put in the SQL text.

                The placeholder is "%s" because querydb() hands the
                arguments to psycopg2's cursor.execute(), which only
                understands that form. Placeholders must therefore
                appear in the query in the order values were interned.
                """
                self.args.append(val)
                return "%s"

        def textify(self):
                """
                Render the query text. The WHERE clause is omitted when
                there are no constraints.
                """
                s = "SELECT %s\n" % ",".join(self.selections)
                f = "FROM %s\n" % ",".join(self.tables)
                if not self.constraints:
                        return s + f
                w = "WHERE %s\n" % " AND ".join(self.constraints)
                return s + f + w
# end class Query

def regexp_constraint(q, field, value):
        """
        Constrain `field` to equal `value`, or to match it when
        isregexp() says `value` is a regular expression.

        The value is interned in `q` rather than pasted into the SQL.
        The condition is added to `q` and also returned.
        """
        cleanval = q.internval(value)
        if not isregexp(value):
                expr = "%s = %s" % (field, cleanval)
        else:
                # "~" is PostgreSQL's case-sensitive POSIX regexp match
                expr = "%s ~ %s" % (field, cleanval)
        # callers ignore the return value and expect the condition to
        # be in q.constraints afterwards
        q.constrain(expr)
        return expr
# end regexp_constraint

def intrange_constraint(q, field, value):
        """
        Constrain `field` to the inclusive integer range `value`.

        value -- a (lower, upper) pair; either bound may be None to
                 leave that side open.

        Raises TypeError if a bound is neither None nor an int. The
        bounds are formatted straight into the SQL text, so this check
        must not be an assert that -O can strip.
        """
        (lower, upper) = value
        for (bound, op) in ((lower, ">="), (upper, "<=")):
                if bound is None:
                        continue
                if not isinstance(bound, int):
                        raise TypeError("range bound for %s must be an int"
                                        % field)
                q.constrain("%s %s %d" % (field, op, bound))
# end intrange_constraint

def daterange_constraint(q, field, value):
        """
        Constrain `field` to a date range. Not written yet.

        Raises NotImplementedError rather than using an assert, which
        python -O would strip and so let the constraint be dropped
        silently.
        """
        # XXX
        raise NotImplementedError("date range constraints (on %s) are "
                                  "not implemented yet" % field)
# end daterange_constraint

############################################################

# this is old code that needs to be merged into the new stuff or deleted
def oldstuff():
        """
        Old search front end, kept only until it is merged or deleted.

        Builds one Query per group of command-line constraints and
        returns their texts joined with INTERSECT. Non-search requests
        (attachment, message, explicit PR list) are handled first and
        exit the program.

        NOTE(review): `args` is neither a parameter nor a visible
        module-level name; presumably the parsed command line. Several
        helpers called here (get_attachment, get_message, show_prs,
        choke) and the classification_schemes tables are not visible
        in this part of the file either.
        """

        # If we're doing something other than a search, do it now
        if args.attach is not None:
                get_attachment(args.attach)
                exit(0)
        if args.message is not None:
                get_message(args.message)
                exit(0)

        if args.prs is not None and len(args.prs) > 0:
                show_prs(args.prs)
                exit(0)

        #
        # Collect up the search constraints
        #
        
        # 1. Constraints on the PRs table
        # checkprtable records whether any constraint was placed on prq;
        # prq only joins the final intersection when it is set.
        checkprtable = False
        prq = Query()
        prq.select("PRs.id as id")
        prq.addtable("PRs")
        if not args.closed:
                checkprtable = True
                prq.addtable("states")
                prq.constrain("PRs.state = states.name")
                prq.constrain("states.closed = FALSE")
        if args.public:
                checkprtable = True
                prq.constrain("NOT PRs.confidential")
        if args.number is not None:
                checkprtable = True
                intrange_constraint(prq, "PRs.id", args.number)
        if args.synopsis is not None:
                checkprtable = True
                regexp_constraint(prq, "PRs.synopsis", args.synopsis)
        if args.confidential is not None:
                checkprtable = True
                # NOTE(review): typeof is not a Python builtin; unless it
                # is defined elsewhere this raises NameError (here and in
                # the args.locked check below).
                assert(typeof(args.confidential) == bool)
                if args.confidential:
                        prq.constrain("PRs.confidential")
                else:
                        prq.constrain("not PRs.confidential")
        if args.state is not None:
                checkprtable = True
                regexp_constraint(prq, "PRs.state", args.state)
        if args.locked is not None:
                checkprtable = True
                assert(typeof(args.locked) == bool)
                if args.locked:
                        prq.constrain("PRs.locked")
                else:
                        prq.constrain("not PRs.locked")
        if args.arrival_schemaversion is not None:
                checkprtable = True
                intrange_constraint(prq, "PRs.arrival_schemaversion",
                                        args.arrival_schemaversion)
        if args.arrival_date is not None:
                checkprtable = True
                daterange_constraint(prq, "PRs.arrival_date",
                                        args.arrival_date)
        if args.closed_date is not None:
                checkprtable = True
                daterange_constraint(prq, "PRs.closed_date",
                                        args.closed_date)
        if args.last_modified is not None:
                checkprtable = True
                # NOTE(review): QueryBuilder.prtable_fields lists
                # "modified_date", not "last_modified"; confirm the
                # column name.
                daterange_constraint(prq, "PRs.last_modified",
                                        args.last_modified)
        if args.release is not None:
                checkprtable = True
                regexp_constraint(prq, "PRs.release", args.release)
        if args.environment is not None:
                checkprtable = True
                regexp_constraint(prq, "PRs.environment", args.environment)

        # The originator name/email constraints share one join
        if args.originator_name is not None or \
                        args.originator_email is not None:
                prq.addtable("usermail as originator")
                prq.constrain("PRs.originator = originator.id")
        if args.originator_name is not None:
                checkprtable = True
                regexp_constraint(prq, "originator.realname",
                                        args.originator_name)
        if args.originator_email is not None:
                checkprtable = True
                # NOTE(review): this passes args.originator_name, not
                # args.originator_email; looks like a copy-paste slip.
                regexp_constraint(prq, "originator.email",
                                        args.originator_name)
        if args.originator_id is not None:
                checkprtable = True
                intrange_constraint(prq, "PRs.originator", args.originator_id)

        queries = []
        if checkprtable:
                queries.append(prq)

        # 2. Constraints on subscriptions; one separate query each
        if args.responsible is not None:
                sq = Query()
                sq.select("subscriptions.pr as id")
                sq.addtable("subscriptions")
                sq.addtable("users")
                sq.constrain("subscriptions.userid = users.id")
                regexp_constraint(sq, "users.realname", args.responsible)
                sq.constrain("subscriptions.responsible")
                queries.append(sq)
        if args.respondent is not None:
                sq = Query()
                sq.select("subscriptions.pr as id")
                sq.addtable("subscriptions")
                # NOTE(review): the table is aliased "subscribed" but the
                # conditions below still refer to "users" (same in the
                # args.subscribed case).
                sq.addtable("users as subscribed")
                sq.constrain("subscriptions.userid = users.id")
                regexp_constraint(sq, "users.realname", args.respondent)
                sq.constrain("subscriptions.reporter")
                queries.append(sq)
        if args.subscribed is not None:
                sq = Query()
                sq.select("subscriptions.pr as id")
                sq.addtable("subscriptions")
                sq.addtable("users as subscribed")
                sq.constrain("subscriptions.userid = users.id")
                regexp_constraint(sq, "users.realname", args.subscribed)
                queries.append(sq)

        # 3. Constraints on message text
        if args.messages is not None:
                mq = Query()
                mq.select("messages.pr as id")
                mq.addtable("messages")
                # NOTE(review): passes sq, not mq; sq is unbound unless
                # one of the subscription branches above ran.
                regexp_constraint(sq, "messages.text", args.messages)
                queries.append(mq)

        # 4. Constraints on the administrative log
        if args.adminlog is not None:
                aq = Query()
                aq.select("adminlog.pr as id")
                aq.addtable("adminlog")
                # NOTE(review): both calls pass sq, not aq, although the
                # assert below expects two constraints on aq.
                regexp_constraint(sq, "adminlog.change", args.adminlog)
                regexp_constraint(sq, "adminlog.comment", args.adminlog)
                assert(len(aq.constraints) == 2)
                # either column may match, so fold the two into one OR
                x = "%s OR %s" % (aq.constraints[0], aq.constraints[1])
                aq.constraints = [x]
                queries.append(aq)

        if args.anytext is not None:
                choke("--anytext isn't supported yet")

        # 5. One query per classification scheme that was constrained
        for scheme in classification_schemes:   
                # NOTE(review): args is indexed here but accessed by
                # attribute everywhere else in this function.
                if args[scheme] is not None:
                        schemetype = classification_schemetypes[scheme]
                        # NOTE(review): tbl is computed but never used;
                        # addtable below is given schemetype instead.
                        tbl = "%sclass_data" % schemetype
                        cq = Query()
                        cq.select("scheme.pr as id")
                        cq.addtable("%s as scheme" % schemetype)
                        cq.constrain("scheme.scheme = '%s'" % scheme)
                        regexp_constraint(cq, "scheme.value", args[scheme])
                        queries.append(cq)
        # end loop

        # A PR must satisfy every group, hence INTERSECT
        querytexts = [q.textify() for q in queries]
        return "INTERSECT\n".join(querytexts)

############################################################
# printing

class PrintText:
        """
        Plain-text result printer.

        self.lines is true when the output mode is "RAW" or "LIST".
        Only printrow() produces output so far.
        """

        def __init__(self, output):
                self.lines = output in ("RAW", "LIST")

        def printheader(self, row):
                """Text output has no header; nothing to do."""
                return None

        def printrow(self, row):
                """XXX: placeholder that just prints the row as-is."""
                print(row)

        def printfooter(self, row):
                """Text output has no footer; nothing to do."""
                return None
# end class PrintText

class PrintCsv:
        """CSV result printer. Stub: none of the methods do anything yet."""

        def __init__(self, output):
                """Accept the output mode; nothing needs to be kept."""

        def printheader(self, row):
                """XXX: not written yet."""

        def printrow(self, row):
                """XXX: not written yet."""

        def printfooter(self, row):
                """CSV output has no footer; nothing to do."""
# end class PrintCsv

class PrintXml:
        """XML result printer. Stub: none of the methods do anything yet."""

        def __init__(self, output):
                """Accept the output mode; nothing needs to be kept."""

        def printheader(self, row):
                """XXX: not written yet."""

        def printrow(self, row):
                """XXX: not written yet."""

        def printfooter(self, row):
                """XXX: not written yet."""
# end class PrintXml

class PrintJson:
        """JSON result printer. Stub: none of the methods do anything yet."""

        def __init__(self, output):
                """Accept the output mode; nothing needs to be kept."""

        def printheader(self, row):
                """XXX: not written yet."""

        def printrow(self, row):
                """XXX: not written yet."""

        def printfooter(self, row):
                """XXX: not written yet."""
# end class PrintJson

class PrintRdf:
        """RDF result printer. Stub: none of the methods do anything yet."""

        def __init__(self, output):
                """Accept the output mode; nothing needs to be kept."""

        def printheader(self, row):
                """XXX: not written yet."""

        def printrow(self, row):
                """XXX: not written yet."""

        def printfooter(self, row):
                """XXX: not written yet."""
# end class PrintRdf

class PrintRdflike:
        """RDF-like result printer. Stub: none of the methods do anything yet."""

        def __init__(self, output):
                """Accept the output mode; nothing needs to be kept."""

        def printheader(self, row):
                """XXX: not written yet."""

        def printrow(self, row):
                """XXX: not written yet."""

        def printfooter(self, row):
                """XXX: not written yet."""
# end class PrintRdflike

def print_prs(ids):
        """
        Print the PRs in `ids` using the printer for sel.outformat and
        the amount of material named by sel.output.

        "RAW" prints the ids themselves. "LIST" looks up id and synopsis
        for each PR. "HEADERS", "META" and "FULL" have no query yet and
        raise NotImplementedError. Nothing at all is printed when there
        are no ids or no rows.

        NOTE(review): `sel` is not a parameter; this relies on a
        module-level `sel` that is not visible in this part of the
        file, while run_sel() has the selection in hand but passes
        only `ids`. Confirm where it is meant to come from.
        """
        printers = {
                "TEXT": PrintText,
                "CSV": PrintCsv,
                "XML": PrintXml,
                "JSON": PrintJson,
                "RDF": PrintRdf,
                "RDFLIKE": PrintRdflike,
        }
        mkprinter = printers.get(sel.outformat)
        assert mkprinter is not None

        # reset the printer
        printer = mkprinter(sel.output)

        if sel.output == "RAW":
                if not ids:
                        return
                printer.printheader(ids[0])
                for pr in ids:
                        printer.printrow(pr)
                printer.printfooter(ids[0])
                return
        elif sel.output == "LIST":
                # XXX is there a clean way to do this passing the
                # whole list of ids at once?
                # (psycopg2 placeholders are %s, not $1)
                query = "SELECT id, synopsis\n" + \
                        "FROM PRs\n" + \
                        "WHERE id = %s"
        elif sel.output in ("HEADERS", "META", "FULL"):
                # XXX no query written for these yet
                raise NotImplementedError("output mode %s is not "
                                          "implemented yet" % sel.output)
        else:
                assert(False)

        # The header goes out before the first row actually found, and
        # the footer is passed the first row of the last non-empty
        # result; PRs that return no rows are skipped.
        started = False
        footer_row = None
        for pr in ids:
                results = querydb(query, [pr])
                if not results:
                        continue
                if not started:
                        printer.printheader(results[0])
                        started = True
                for r in results:
                        printer.printrow(r)
                footer_row = results[0]
        if started:
                printer.printfooter(footer_row)
# end print_prs

# XXX if in public mode we need to check if the PR is public
def print_message(pr, msgnum):
        """
        Write message number `msgnum` of PR `pr` to stdout as a few
        mail-style header lines, a blank line, and the body.

        If no such message exists, an error is written to stderr and
        the program exits with status 1.

        NOTE(review): `organization` is not defined in the visible part
        of the file; presumably a module-level setting.
        """
        # psycopg2 placeholders are %s; both values are passed as
        # parameters rather than pasted into the text.
        query = "SELECT users.username AS username,\n" + \
                "       users.realname AS realname,\n" + \
                "       messages.id AS id, parent_id,\n" + \
                "       posttime, mimetype, body\n" + \
                "FROM messages, users\n" + \
                "WHERE messages.who = users.id\n" + \
                "  AND messages.pr = %s\n" + \
                "  AND messages.number_in_pr = %s\n"
        # Note that while pr is safe, msgnum came from the commandline
        # and may not be.
        results = querydb(query, [pr, msgnum])
        if not results:
                sys.stderr.write("No message %s in PR %s.\n" % (msgnum, pr))
                sys.exit(1)
        # exactly one row is expected; more than one still raises
        [result] = results
        (username, realname, id, parent_id, posttime, mimetype, body) = result
        # XXX honor mimetype
        # XXX honor output format (e.g. html)
        sys.stdout.write("From swallowtail@%s  %s\n" % (organization,posttime))
        sys.stdout.write("From: %s (%s)\n" % (username, realname))
        sys.stdout.write("References: %s\n" % parent_id)
        sys.stdout.write("Date: %s\n" % posttime)
        sys.stdout.write("Content-Type: %s\n" % mimetype)
        sys.stdout.write("\n")
        sys.stdout.write(body)
# end print_message

# XXX if in public mode we need to check if the PR is public
def print_attachment(pr, attachnum):
        """
        Write the body of attachment number `attachnum` of PR `pr` to
        stdout.

        If no such attachment exists, an error is written to stderr and
        the program exits with status 1.
        """
        # psycopg2 placeholders are %s; both values are passed as
        # parameters rather than pasted into the text.
        query = "SELECT a.mimetype as mimetype, a.body as body\n" + \
                "FROM messages, attachments as a\n" + \
                "WHERE messages.pr = %s\n" + \
                "  AND messages.id = a.msgid\n" + \
                "  AND a.number_in_pr = %s\n"
        # Note that while pr is safe, attachnum came from the
        # commandline and may not be.
        results = querydb(query, [pr, attachnum])
        if not results:
                sys.stderr.write("No attachment %s in PR %s.\n" %
                                 (attachnum, pr))
                sys.exit(1)
        # exactly one row is expected; more than one still raises
        [result] = results
        (mimetype, body) = result
        # XXX honor mimetype
        # XXX need an http output mode so we can send the mimetype!
        sys.stdout.write(body)
# end print_attachment

############################################################
# AST for input query

class Invocation:
        """
        AST for one run of the program: a list of operations.

        The node classes (Query, Order, Search, Selection, Op) are
        nested here, each with a dump(d) method that writes a
        description through a Dumper-like object. Nodes are created
        with the mk*() factories. The factories are static methods, so
        they work both as Invocation.mkterm(...) and through an
        instance.
        """

        # query node types
        Q_TERM = 1
        Q_SQL = 2
        class Query:
                """One query: a search term (Q_TERM) or literal SQL (Q_SQL)."""
                def __init__(self, type):
                        self.type = type
                def dump(self, d):
                        if self.type == Invocation.Q_TERM:
                                d.write("query.term({})".format(self.term))
                        else:
                                d.write("query.sql({})".format(self.sql))
        @staticmethod
        def mkterm(term):
                self = Invocation.Query(Invocation.Q_TERM)
                self.term = term
                return self
        @staticmethod
        def mksql(s):
                self = Invocation.Query(Invocation.Q_SQL)
                self.sql = s
                return self

        class Order:
                """Sort key: a field name plus a reverse flag."""
                def __init__(self, field, rev = False):
                        self.field = field
                        self.rev = rev
                def dump(self, d):
                        d.write("order({}, {})".format(self.field, self.rev))
        @staticmethod
        def mkoldest():
                return Invocation.Order("number")
        @staticmethod
        def mknewest():
                return Invocation.Order("number", True)
        @staticmethod
        def mkstaleness():
                return Invocation.Order("modified_date", True)
        @staticmethod
        def mkfield(field):
                return Invocation.Order(field)
        @staticmethod
        def mkrevfield(field):
                return Invocation.Order(field, True)

        class Search:
                """A search: queries, open/public restrictions, and orders."""
                def __init__(self, qs, openonly, publiconly, os):
                        self.queries = qs
                        self.openonly = openonly
                        self.publiconly = publiconly
                        self.orders = os
                def dump(self, d):
                        d.write("search({}, {})".format(
                                "open" if self.openonly else "closed",
                                "public" if self.publiconly else "privileged"))
                        d.indent()
                        d.write("queries")
                        d.indent()
                        for query in self.queries:
                                query.dump(d)
                        d.unindent()
                        d.write("orders")
                        d.indent()
                        for order in self.orders:
                                order.dump(d)
                        d.unindent()
                        d.unindent()

        # selection node types
        S_PR = 1
        S_MESSAGE = 2
        S_ATTACHMENT = 3
        class Selection:
                """What to print for the PRs found: PR, message or attachment."""
                def __init__(self, type):
                        self.type = type
                def dump(self, d):
                        if self.type == Invocation.S_PR:
                                d.write("selection.pr({}, {})".format(
                                        self.output, self.outformat))
                        elif self.type == Invocation.S_MESSAGE:
                                d.write("selection.message({})".format(
                                        self.message))
                        else:
                                d.write("selection.attachment({})".format(
                                        self.attachment))
        @staticmethod
        def mkpr(output, outformat):
                self = Invocation.Selection(Invocation.S_PR)
                self.output = output
                self.outformat = outformat
                return self
        @staticmethod
        def mkmessage(arg):
                self = Invocation.Selection(Invocation.S_MESSAGE)
                self.message = arg
                return self
        @staticmethod
        def mkattachment(arg):
                self = Invocation.Selection(Invocation.S_ATTACHMENT)
                self.attachment = arg
                return self

        # operation node types
        OP_FIELDS = 1
        OP_SHOW = 2
        OP_RANGE = 3
        OP_SEARCH = 4
        class Op:
                """One operation: fields, show, range, or search."""
                def __init__(self, type):
                        self.type = type
                def dump(self, d):
                        if self.type == Invocation.OP_FIELDS:
                                d.write("op.fields")
                        elif self.type == Invocation.OP_SHOW:
                                d.write("op.show({})".format(self.field))
                        elif self.type == Invocation.OP_RANGE:
                                d.write("op.range({})".format(self.field))
                        else:
                                d.write("op.search:")
                                d.indent()
                                self.search.dump(d)
                                for sel in self.sels:
                                        sel.dump(d)
                                d.unindent()
        @staticmethod
        def mkfields():
                return Invocation.Op(Invocation.OP_FIELDS)
        @staticmethod
        def mkshow(field):
                self = Invocation.Op(Invocation.OP_SHOW)
                self.field = field
                return self
        @staticmethod
        def mkrange(field):
                self = Invocation.Op(Invocation.OP_RANGE)
                self.field = field
                return self
        @staticmethod
        def mksearch(s, sels):
                self = Invocation.Op(Invocation.OP_SEARCH)
                self.search = s
                self.sels = sels
                return self

        def __init__(self, ops):
                self.ops = ops
        def dump(self, d):
                d.write("invocation: {} ops".format(len(self.ops)))
                d.indent()
                for op in self.ops:
                        op.dump(d)
                d.unindent()
# end class Invocation

############################################################
# run (eval the SQL and print the results)

def run_sel(sel, ids):
        # Print what one selection asks for.
        #
        # sel: the selection; sel.type is one of the Invocation.S_* tags.
        # ids: list of PR ids produced by the search.
        #
        # Exits with status 1 if the search matched nothing, or if a
        # message or attachment is requested and the search matched
        # more than one PR.

        # Report an empty result as such, rather than as "multiple PRs".
        if not ids:
                sys.stderr.write("No PRs matched.\n")
                sys.exit(1)

        # The S_* tags are attributes of Invocation, not module globals.
        if sel.type == Invocation.S_PR:
                print_prs(ids)
        elif sel.type == Invocation.S_MESSAGE:
                if len(ids) != 1:
                        sys.stderr.write("Cannot retrieve messages " +
                                "from multiple PRs.\n")
                        sys.exit(1)
                print_message(ids[0], sel.message)
        elif sel.type == Invocation.S_ATTACHMENT:
                if len(ids) != 1:
                        sys.stderr.write("Cannot retrieve attachments " +
                                "from multiple PRs.\n")
                        sys.exit(1)
                # NOTE(review): this prints the attachment through
                # print_message(); looks like a copy of the branch above.
                # Confirm whether a dedicated attachment printer is meant.
                print_message(ids[0], sel.attachment)
        else:
                raise AssertionError("unknown selection type")

def run_op(op):
        # Execute one (already compiled) operation.
        #
        # For a search, op.search is passed to querydb() as the query
        # and every selection in op.sels is run on the resulting ids.

        # The OP_* tags are attributes of Invocation, not module globals.
        if op.type == Invocation.OP_FIELDS:
                list_fields()
        elif op.type == Invocation.OP_SHOW:
                describe_field(op.field)
        elif op.type == Invocation.OP_RANGE:
                print_field_range(op.field)
        elif op.type == Invocation.OP_SEARCH:
                sql = op.search
                # XXX nothing in the compile step attaches the interned
                # query strings to the op yet. Use them if present
                # instead of dying on the missing attribute.
                args = getattr(op, "args", [])
                ids = querydb(sql, args)
                for s in op.sels:
                        run_sel(s, ids)
        else:
                raise AssertionError("unknown op type")

def run(ast):
        # Execute each op of the invocation, in list order.
        for operation in ast.ops:
                run_op(operation)

############################################################
# compile (convert the AST so the searches are pure SQL)

#
# XXX this doesn't work, we need to keep the interned strings
# on return from compile_query.
#

def matches(s, rx):
        # Return True if the regular expression rx matches anywhere in
        # the string s, False otherwise. Patterns that must cover the
        # whole string anchor themselves with ^ and $.
        #
        # Note that in Python "$" also matches just before a trailing
        # newline.

        # Imported here because the module does not import re at the top.
        import re
        return re.search(rx, s) is not None

def compile_query(q):
        # Convert one query node into SQL text and return it.
        #
        # q.type selects the form:
        #   Q_QSTRING  q.string is split on whitespace and the terms
        #              are ANDed together.
        #   Q_TSTRING  q.string is a single term (see below).
        #   Q_OR       UNION of the compiled q.args.
        #   Q_AND      INTERSECT of the compiled q.args.
        #   Q_SQL      q.sql is returned unchanged.
        #
        # XXX the strings interned into the QueryBuilder are not
        # returned to the caller, so they never reach the database.
        #
        # NOTE(review): the Q_* tags are taken to live in Invocation,
        # like the S_* and OP_* tags; confirm against the class.

        if q.type == Invocation.Q_QSTRING:
                # XXX should use a split that honors quotes
                terms = [dotstring(t) for t in q.string.split()]
                return compile_query(doand(terms))
        if q.type == Invocation.Q_TSTRING:
                qb = QueryBuilder()
                s = q.string
                if matches(s, "^[0-9]+$"):
                        f = qb.getfield("number")
                        # s is user-supplied; int() makes sure that only
                        # a number can reach the SQL text.
                        qb.addwhere("%s = %d" % (f, int(s)))
                elif matches(s, "^[0-9]+-[0-9]+$"):
                        f = qb.getfield("number")
                        ss = s.split("-")
                        qb.addwhere("%s >= %d" % (f, int(ss[0])))
                        qb.addwhere("%s <= %d" % (f, int(ss[1])))
                elif matches(s, "^[0-9]+-$"):
                        f = qb.getfield("number")
                        ss = s.split("-")
                        qb.addwhere("%s >= %d" % (f, int(ss[0])))
                elif matches(s, "^-[0-9]+$"):
                        f = qb.getfield("number")
                        ss = s.split("-")
                        qb.addwhere("%s <= %d" % (f, int(ss[1])))
                elif matches(s, "^-[^:]+:[^:]+$"):
                        # This must be tested before the plain FIELD:TEXT
                        # form, whose pattern also accepts a leading "-".
                        # XXX honor quoted terms
                        # XXX <> or NOT LIKE?
                        # Drop the "-" so it is not part of the field name.
                        ss = s[1:].split(":")
                        # ss[0] is not clean but if it's crap it won't match
                        f = qb.getfield(ss[0])
                        # ss[1] is not clean, so intern it for safety
                        v = qb.intern(ss[1])
                        qb.addwhere("%s <> %s" % (f, v))
                elif matches(s, "^[^:]+:[^:]+$"):
                        # XXX honor quoted terms
                        # XXX = or LIKE?
                        ss = s.split(":")
                        # ss[0] is not clean but if it's crap it won't match
                        f = qb.getfield(ss[0])
                        # ss[1] is not clean, so intern it for safety
                        v = qb.intern(ss[1])
                        qb.addwhere("%s = %s" % (f, v))
                elif matches(s, "^-"):
                        # XXX <> or NOT LIKE?
                        f = qb.getfield("alltext")
                        # The text after the "-" is what to exclude; it
                        # is not clean, so intern it for safety.
                        v = qb.intern(s[1:])
                        qb.addwhere("%s <> %s" % (f, v))
                else:
                        # XXX = or LIKE?
                        f = qb.getfield("alltext")
                        # s is not clean, so intern it for safety
                        v = qb.intern(s)
                        qb.addwhere("%s = %s" % (f, v))

                # XXX also does not handle:
                #
                # field: with no string (supposed to use a default
                # search string)
                #
                # generated search fields that parse dates:
                # {arrived,closed,modified,etc.}-{before,after}:date
                #
                # stale:time

                return qb.build("PRs.id")
        # end Q_TSTRING case
        if q.type == Invocation.Q_OR:
                subqueries = ["(" + compile_query(sq) + ")" for sq in q.args]
                return " UNION ".join(subqueries)
        if q.type == Invocation.Q_AND:
                subqueries = ["(" + compile_query(sq) + ")" for sq in q.args]
                return " INTERSECT ".join(subqueries)
        if q.type == Invocation.Q_SQL:
                return q.sql
        raise AssertionError("unknown query type")
# end compile_query

def compile_order(qb, o):
        # Return the ORDER BY term for one ordering.
        #
        # qb: query builder; qb.getfield() maps o.field to the text
        #     used to reference it.
        # o:  ordering with a field name (o.field) and a reverse flag
        #     (o.rev).
        term = qb.getfield(o.field)
        if o.rev:
                # The SQL keyword is DESC; "DESCENDING" is a syntax error.
                term = term + " DESC"
        return term

def compile_search(s):
        # Compile a whole search into SQL text and return it.
        #
        # s.queries are ORed together and compiled. If the open/public
        # filters or the orderings add anything, the compiled query is
        # wrapped in an outer query built with a second QueryBuilder;
        # otherwise it is returned as is.
        qb2 = QueryBuilder()

        # multiple query strings are treated as OR
        query = compile_query(door(s.queries))

        if s.openonly:
                qb2.addwhere("not %s" % qb2.getfield("closed"))
        if s.publiconly:
                qb2.addwhere("not %s" % qb2.getfield("confidential"))

        orders = [compile_order(qb2, o) for o in s.orders]
        order = ", ".join(orders)
        if order:
                qb2.setorder(order)

        if qb2.nonempty():
                qb2.addjoin(query, "search")
                qb2.addjoin("PRs")
                qb2.addwhere("search = PRs.id")
                # NOTE(review): build() gets a list here but a plain
                # string in compile_query(); confirm which it expects.
                query = qb2.build(["search"])

        return query
# end compile_search

def compile_op(op):
        # Compile one op and return it. Search ops get their search
        # replaced by its compiled form; other ops are left unchanged.

        # OP_SEARCH is an attribute of Invocation, not a module global.
        if op.type == Invocation.OP_SEARCH:
                op.search = compile_search(op.search)
        return op

def compile(ast):
        # Compile every op of the invocation. ast.ops is replaced in
        # place, and ast is also returned so that the result can be
        # assigned by the caller.
        ast.ops = [compile_op(op) for op in ast.ops]
        return ast

############################################################
# arg handling

#
# I swear, all getopt interfaces suck. You have to write your own to
# not go mad.
#

# Usage.
def usage():
        # Write the usage message to stderr. The option names listed
        # here need to be kept in step with the ones getargs() accepts.
        sys.stderr.write("""
query-pr: search for and retrieve problem reports
usage: query-pr [options] [searchterms] Query the database.
       query-pr [options] --sql QUERY   Execute QUERY as the search.
       query-pr [options] -s QUERY      Same as --sql.
       query-pr --fields                List database fields.
       query-pr --show FIELD            Print info about database field.
       query-pr --range FIELD           Print extant range for database field.
       query-pr --help / -h             Print this message.
       query-pr --version / -v          Print version and exit.
options:
       --search STRING                  Forcibly treat STRING as a search term.
       --message NUM / -m NUM           Print a message by its ID number.
       --attachment NUM / -a NUM        Print an attachment by its ID number.
       --paranoid                       Deny unsafe settings.
filter options:
       --open                           Exclude closed PRs. (default)
       --closed                         Include closed PRs.
       --public                         Exclude confidential PRs.
       --privileged                     Include confidential PRs. (default)
sort options:
       --oldest                         Sort with oldest PRs first. (default)
       --newest                         Sort with newest PRs first.
       --staleness                      Sort by last modification time.
       --orderby FIELD                  Sort by specific field.
       --revorderby FIELD               Sort by specific field, reversed.
output options:
       --raw / -r                       Print raw SQL output.
       --list / -l                      Print in list form.
       --headers                        Print headers only.
       --meta / --metadata              Print all metadata.
       --full / -f                      Print entire PR.
       --text                           Print text. (default)
       --csv                            Print CSV.
       --xml                            Print XML.
       --json                           Print JSON.
       --rdf                            Print RDF.
       --rdflike                        Print RDF-like text.
search terms:
       NUM                              Single PR by number.
       NUM-[NUM]                        Range of PRs by number.
       TEXT                             Search string
       FIELD:TEXT                       Search string for a particular field.
       FIELD:                           Use the field's default search string.
derived fields:
       arrived-before:DATE              Arrival date before DATE.
       arrived-after:DATE               Arrival date after DATE.
       closed-before:DATE               Close date before DATE.
       closed-after:DATE                Close date after DATE, or none.
       last-modified-before:DATE        Last modified before DATE.
       last-modified-after:DATE         Last modified after DATE.
       stale:TIME                       Last modified at least TIME ago.
Explicit SQL queries should return lists of PR numbers (only).
""")

# Widget to hold argv and allow peeling args off one at a time.
class ArgHolder:
        # Cursor over an argument vector. Reading starts at index 1, so
        # argv[0] is never handed out.
        def __init__(self, argv):
                self.argc = len(argv)
                self.argv = argv
                self.pos = 1
        def next(self):
                # Return the next argument and advance, or None once the
                # arguments are used up.
                if self.pos >= self.argc:
                        return None
                ret = self.argv[self.pos]
                self.pos += 1
                return ret
        def getarg(self, opt):
                # Return the next argument as the value of option opt.
                # If there is none, complain on stderr and exit with
                # status 1.
                ret = self.next()
                if ret is None:
                        msg = "Option {} requires an argument\n".format(opt)
                        sys.stderr.write(msg)
                        # sys.exit, not the exit() helper that the site
                        # module provides for interactive use.
                        sys.exit(1)
                return ret

# Read the argument list and convert it into an Invocation.
def getargs(argv):
        # Parse the command line.
        #
        # argv: full argument vector; argv[0] is skipped.
        #
        # Returns (invocation, paranoid): the Invocation holding the
        # requested ops, and whether --paranoid was given.
        #
        # Exits with status 1 on an unknown option, on a missing option
        # argument, or if orderings or selections are given without any
        # query. --help and --version print and exit with status 0.
        #
        # --paranoid only takes effect for the options that follow it:
        # --sql/-s and --privileged are rejected as unknown once it has
        # been seen.

        # Results
        ops = []
        orders = []
        queries = []
        selections = []
        output = "LIST"
        outformat = "TEXT"
        openonly = True
        publiconly = False
        paranoid = False
        nomoreoptions = False

        #
        # Functions for the options
        #

        def do_paranoid():
                nonlocal paranoid, publiconly
                paranoid = True
                publiconly = True

        def do_help():
                usage()
                sys.exit(0)
        def do_version():
                msg = "query-pr {}\n".format(program_version)
                sys.stdout.write(msg)
                sys.exit(0)
        def do_fields():
                ops.append(Invocation.mkfields())
        def do_show(field):
                ops.append(Invocation.mkshow(field))
        def do_range(field):
                ops.append(Invocation.mkrange(field))

        def do_search(term):
                queries.append(Invocation.mkterm(term))
        def do_sql(text):
                assert(not paranoid)
                queries.append(Invocation.mksql(text))

        def do_open():
                nonlocal openonly
                openonly = True
        def do_closed():
                nonlocal openonly
                openonly = False
        def do_public():
                nonlocal publiconly
                publiconly = True
        def do_privileged():
                nonlocal publiconly
                assert(not paranoid)
                publiconly = False

        def do_oldest():
                orders.append(Invocation.mkoldest())
        def do_newest():
                orders.append(Invocation.mknewest())
        def do_staleness():
                orders.append(Invocation.mkstaleness())
        def do_orderby(field):
                orders.append(Invocation.mkfield(field))
        def do_revorderby(field):
                orders.append(Invocation.mkrevfield(field))

        def do_message(n):
                selections.append(Invocation.mkmessage(n))
        def do_attachment(n):
                selections.append(Invocation.mkattachment(n))

        def do_raw():
                nonlocal output
                output = "RAW"
        def do_list():
                nonlocal output
                output = "LIST"
        def do_headers():
                nonlocal output
                output = "HEADERS"
        def do_meta():
                nonlocal output
                output = "META"
        def do_full():
                nonlocal output
                output = "FULL"

        def do_text():
                nonlocal outformat
                outformat = "TEXT"
        def do_csv():
                nonlocal outformat
                outformat = "CSV"
        def do_xml():
                nonlocal outformat
                outformat = "XML"
        def do_json():
                nonlocal outformat
                outformat = "JSON"
        def do_rdf():
                nonlocal outformat
                outformat = "RDF"
        def do_rdflike():
                nonlocal outformat
                outformat = "RDFLIKE"

        def do_unknown(opt):
                sys.stderr.write("Unknown option {}\n".format(opt))
                sys.exit(1)

        args = ArgHolder(argv)
        while True:
                arg = args.next()
                if arg is None:
                        break

                # For the --option=VALUE forms: everything after the
                # first "=". Taking it this way avoids hand-counted
                # slice offsets.
                value = arg.partition("=")[2]

                # startswith() rather than arg[0] so that an empty
                # argument is a (blank) search term, not an IndexError.
                if nomoreoptions or not arg.startswith("-"):
                        do_search(arg)
                elif arg == "--":
                        nomoreoptions = True
                # Long options
                elif arg == "--attachment":
                        do_attachment(args.getarg(arg))
                elif arg.startswith("--attachment="):
                        do_attachment(value)
                elif arg == "--closed":
                        do_closed()
                elif arg == "--csv":
                        do_csv()
                elif arg == "--fields":
                        do_fields()
                elif arg == "--full":
                        do_full()
                elif arg == "--headers":
                        do_headers()
                elif arg == "--help":
                        do_help()
                elif arg == "--json":
                        do_json()
                elif arg == "--list":
                        do_list()
                elif arg == "--message":
                        do_message(args.getarg(arg))
                elif arg.startswith("--message="):
                        do_message(value)
                elif arg == "--meta":
                        do_meta()
                elif arg == "--metadata":
                        do_meta()
                elif arg == "--newest":
                        do_newest()
                elif arg == "--oldest":
                        do_oldest()
                elif arg == "--orderby":
                        do_orderby(args.getarg(arg))
                elif arg.startswith("--orderby="):
                        do_orderby(value)
                elif arg == "--open":
                        do_open()
                elif arg == "--paranoid":
                        do_paranoid()
                elif arg == "--public":
                        do_public()
                elif arg == "--privileged" and not paranoid:
                        do_privileged()
                elif arg == "--range":
                        do_range(args.getarg(arg))
                elif arg.startswith("--range="):
                        do_range(value)
                elif arg == "--raw":
                        do_raw()
                elif arg == "--rdf":
                        do_rdf()
                elif arg == "--rdflike":
                        do_rdflike()
                elif arg == "--revorderby":
                        do_revorderby(args.getarg(arg))
                elif arg.startswith("--revorderby="):
                        do_revorderby(value)
                # --search-string is the spelling the usage message
                # gives; accept it as well as --search.
                elif arg == "--search" or arg == "--search-string":
                        do_search(args.getarg(arg))
                elif (arg.startswith("--search=") or
                      arg.startswith("--search-string=")):
                        do_search(value)
                elif arg == "--show":
                        do_show(args.getarg(arg))
                elif arg.startswith("--show="):
                        do_show(value)
                elif arg == "--sql" and not paranoid:
                        do_sql(args.getarg(arg))
                elif arg.startswith("--sql=") and not paranoid:
                        do_sql(value)
                elif arg == "--staleness":
                        do_staleness()
                elif arg == "--text":
                        do_text()
                elif arg == "--version":
                        do_version()
                elif arg == "--xml":
                        do_xml()
                elif arg.startswith("--"):
                        do_unknown(arg)
                else:
                        # short options
                        i = 1
                        n = len(arg)
                        while i < n:
                                opt = arg[i]
                                i += 1
                                # The value of a short option is the
                                # rest of this word if there is any,
                                # else the next word. Either way this
                                # word is used up.
                                def getarg():
                                        nonlocal i
                                        if i < n:
                                                ret = arg[i:]
                                        else:
                                                ret = args.getarg("-" + opt)
                                        i = n
                                        return ret

                                if opt == "a":
                                        do_attachment(getarg())
                                elif opt == "f":
                                        do_full()
                                elif opt == "h":
                                        do_help()
                                elif opt == "l":
                                        do_list()
                                elif opt == "m":
                                        do_message(getarg())
                                elif opt == "s" and not paranoid:
                                        do_sql(getarg())
                                elif opt == "r":
                                        do_raw()
                                elif opt == "v":
                                        do_version()
                                else:
                                        do_unknown("-" + opt)

        # Now convert what we got to a single thing.
        if queries:
                if not orders:
                        orders = [Invocation.mkoldest()]
                if not selections:
                        selections = [Invocation.mkpr(output, outformat)]
                search = Invocation.Search(queries, openonly, publiconly, orders)
                ops.append(Invocation.mksearch(search, selections))
        else:
                if orders:
                        msg = "No queries given for requested orderings\n"
                        sys.stderr.write(msg)
                        sys.exit(1)
                if selections:
                        msg = "No queries given for requested selections\n"
                        sys.stderr.write(msg)
                        sys.exit(1)

        return (Invocation(ops), paranoid)
# end getargs

############################################################
# main

# Parse the command line, compile the searches to SQL, and run them
# against the database.
(todo, paranoid) = getargs(sys.argv)
#todo.dump(Dumper(sys.stdout))

opendb(paranoid)
try:
        fetch_classifications()
        # compile() assigns the compiled ops to todo.ops, so it is
        # called for that effect; todo itself is not rebound.
        compile(todo)
        run(todo)
finally:
        # Also reached when run() bails out through exit.
        closedb()
sys.exit(0)