Repair insufficiently careful type checking for SQL-language functions:
authorTom Lane <[email protected]>
Fri, 2 Feb 2007 00:03:44 +0000 (00:03 +0000)
committerTom Lane <[email protected]>
Fri, 2 Feb 2007 00:03:44 +0000 (00:03 +0000)
we should check that the function code returns the claimed result datatype
every time we parse the function for execution.  Formerly, for simple
scalar result types we assumed the creation-time check was sufficient, but
this fails if the function selects from a table that's been redefined since
then, and even more obviously fails if check_function_bodies had been OFF.

This is a significant security hole: not only can one trivially crash the
backend, but with appropriate misuse of pass-by-reference datatypes it is
possible to read out arbitrary locations in the server process's memory,
which could allow retrieving database content the user should not be able
to see.  Our thanks to Jeff Trout for the initial report.

Security: CVE-2007-0555

src/backend/executor/functions.c
src/backend/optimizer/util/clauses.c

index 230e899f4a5e9a557dcdf2d451f44c0cc5918f80..5203671f8044faf2a439ecaa40e1d97ad7454701 100644 (file)
@@ -62,7 +62,7 @@ typedef struct
 {
        Oid                *argtypes;           /* resolved types of arguments */
        Oid                     rettype;                /* actual return type */
-       int                     typlen;                 /* length of the return type */
+       int16           typlen;                 /* length of the return type */
        bool            typbyval;               /* true if return type is pass by value */
        bool            returnsTuple;   /* true if returning whole tuple result */
        bool            shutdown_reg;   /* true if registered shutdown callback */
@@ -152,12 +152,9 @@ init_sql_fcache(FmgrInfo *finfo)
        Oid                     foid = finfo->fn_oid;
        Oid                     rettype;
        HeapTuple       procedureTuple;
-       HeapTuple       typeTuple;
        Form_pg_proc procedureStruct;
-       Form_pg_type typeStruct;
        SQLFunctionCachePtr fcache;
        Oid                *argOidVect;
-       bool            haspolyarg;
        char       *src;
        int                     nargs;
        List       *queryTree_list;
@@ -194,35 +191,17 @@ init_sql_fcache(FmgrInfo *finfo)
 
        fcache->rettype = rettype;
 
+       /* Fetch the typlen and byval info for the result type */
+       get_typlenbyval(rettype, &fcache->typlen, &fcache->typbyval);
+
        /* Remember if function is STABLE/IMMUTABLE */
        fcache->readonly_func =
                (procedureStruct->provolatile != PROVOLATILE_VOLATILE);
 
-       /* Now look up the actual result type */
-       typeTuple = SearchSysCache(TYPEOID,
-                                                          ObjectIdGetDatum(rettype),
-                                                          0, 0, 0);
-       if (!HeapTupleIsValid(typeTuple))
-               elog(ERROR, "cache lookup failed for type %u", rettype);
-       typeStruct = (Form_pg_type) GETSTRUCT(typeTuple);
-
-       /*
-        * get the type length and by-value flag from the type tuple; also do
-        * a preliminary check for returnsTuple (this may prove inaccurate,
-        * see below).
-        */
-       fcache->typlen = typeStruct->typlen;
-       fcache->typbyval = typeStruct->typbyval;
-       fcache->returnsTuple = (typeStruct->typtype == 'c' ||
-                                                       rettype == RECORDOID);
-
        /*
-        * Parse and rewrite the queries.  We need the argument type info to
-        * pass to the parser.
+        * We need the actual argument types to pass to the parser.
         */
        nargs = procedureStruct->pronargs;
-       haspolyarg = false;
-
        if (nargs > 0)
        {
                int                     argnum;
@@ -245,7 +224,6 @@ init_sql_fcache(FmgrInfo *finfo)
                                                         errmsg("could not determine actual type of argument declared %s",
                                                                        format_type_be(argOidVect[argnum]))));
                                argOidVect[argnum] = argtype;
-                               haspolyarg = true;
                        }
                }
        }
@@ -253,6 +231,9 @@ init_sql_fcache(FmgrInfo *finfo)
                argOidVect = NULL;
        fcache->argtypes = argOidVect;
 
+       /*
+        * Parse and rewrite the queries in the function text.
+        */
        tmp = SysCacheGetAttr(PROCOID,
                                                  procedureTuple,
                                                  Anum_pg_proc_prosrc,
@@ -264,24 +245,25 @@ init_sql_fcache(FmgrInfo *finfo)
        queryTree_list = pg_parse_and_rewrite(src, argOidVect, nargs);
 
        /*
-        * If the function has any arguments declared as polymorphic types,
-        * then it wasn't type-checked at definition time; must do so now.
+        * Check that the function returns the type it claims to.  Although
+        * in simple cases this was already done when the function was defined,
+        * we have to recheck because database objects used in the function's
+        * queries might have changed type.  We'd have to do it anyway if the
+        * function had any polymorphic arguments.
         *
-        * Also, force a type-check if the declared return type is a rowtype; we
-        * need to find out whether we are actually returning the whole tuple
-        * result, or just regurgitating a rowtype expression result. In the
+        * Note: we set fcache->returnsTuple according to whether we are
+        * returning the whole tuple result or just a single column.  In the
         * latter case we clear returnsTuple because we need not act different
-        * from the scalar result case.
+        * from the scalar result case, even if it's a rowtype column.
         *
-        * In the returnsTuple case, check_sql_fn_retval will also construct
-        * JunkFilter we can use to coerce the returned rowtype to the desired
+        * In the returnsTuple case, check_sql_fn_retval will also construct a
+        * JunkFilter we can use to coerce the returned rowtype to the desired
         * form.
         */
-       if (haspolyarg || fcache->returnsTuple)
-               fcache->returnsTuple = check_sql_fn_retval(rettype,
-                                                                                                  get_typtype(rettype),
-                                                                                                  queryTree_list,
-                                                                                                  &fcache->junkFilter);
+       fcache->returnsTuple = check_sql_fn_retval(rettype,
+                                                                                          get_typtype(rettype),
+                                                                                          queryTree_list,
+                                                                                          &fcache->junkFilter);
 
        /* Finally, plan the queries */
        fcache->func_state = init_execution_state(queryTree_list,
@@ -289,7 +271,6 @@ init_sql_fcache(FmgrInfo *finfo)
 
        pfree(src);
 
-       ReleaseSysCache(typeTuple);
        ReleaseSysCache(procedureTuple);
 
        finfo->fn_extra = (void *) fcache;
@@ -862,11 +843,10 @@ ShutdownSQLFunction(Datum arg)
  * the final query in the function.  We do some ad-hoc type checking here
  * to be sure that the user is returning the type he claims.
  *
- * This is normally applied during function definition, but in the case
- * of a function with polymorphic arguments, we instead apply it during
- * function execution startup. The rettype is then the actual resolved
- * output type of the function, rather than the declared type. (Therefore,
- * we should never see ANYARRAY or ANYELEMENT as rettype.)
+ * For a polymorphic function the passed rettype must be the actual resolved
+ * output type of the function; we should never see ANYARRAY or ANYELEMENT
+ * as rettype.  (This means we can't check the type during function definition
+ * of a polymorphic function.)
  *
  * The return value is true if the function returns the entire tuple result
  * of its final SELECT, and false otherwise.  Note that because we allow
index 336140f361170d82a6408d96044617358d1dded5..5a2d3bc1356d6a28b7a473354c56759dbca915f9 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "postgres.h"
 
+#include "catalog/pg_aggregate.h"
 #include "catalog/pg_language.h"
 #include "catalog/pg_proc.h"
 #include "catalog/pg_type.h"
@@ -33,6 +34,7 @@
 #include "optimizer/var.h"
 #include "parser/analyze.h"
 #include "parser/parse_clause.h"
+#include "parser/parse_coerce.h"
 #include "parser/parse_expr.h"
 #include "tcop/tcopprot.h"
 #include "utils/acl.h"
@@ -47,6 +49,7 @@
 typedef struct
 {
        List       *active_fns;
+       Node       *case_val;
        bool            estimate;
 } eval_const_expressions_context;
 
@@ -58,8 +61,7 @@ typedef struct
 } substitute_actual_parameters_context;
 
 static bool contain_agg_clause_walker(Node *node, void *context);
-static bool contain_distinct_agg_clause_walker(Node *node, void *context);
-static bool count_agg_clause_walker(Node *node, int *count);
+static bool count_agg_clauses_walker(Node *node, AggClauseCounts *counts);
 static bool expression_returns_set_walker(Node *node, void *context);
 static bool contain_subplans_walker(Node *node, void *context);
 static bool contain_mutable_functions_walker(Node *node, void *context);
@@ -358,71 +360,108 @@ contain_agg_clause_walker(Node *node, void *context)
 }
 
 /*
- * contain_distinct_agg_clause
- *       Recursively search for DISTINCT Aggref nodes within a clause.
+ * count_agg_clauses
+ *       Recursively count the Aggref nodes in an expression tree.
+ *
+ *       Note: this also checks for nested aggregates, which are an error.
  *
- *       Returns true if any DISTINCT aggregate found.
+ * We not only count the nodes, but attempt to estimate the total space
+ * needed for their transition state values if all are evaluated in parallel
+ * (as would be done in a HashAgg plan).  See AggClauseCounts for the exact
+ * set of statistics returned.
+ *
+ * NOTE that the counts are ADDED to those already in *counts ... so the
+ * caller is responsible for zeroing the struct initially.
  *
  * This does not descend into subqueries, and so should be used only after
  * reduction of sublinks to subplans, or in contexts where it's known there
  * are no subqueries.  There mustn't be outer-aggregate references either.
  */
-bool
-contain_distinct_agg_clause(Node *clause)
+void
+count_agg_clauses(Node *clause, AggClauseCounts *counts)
 {
-       return contain_distinct_agg_clause_walker(clause, NULL);
+       /* no setup needed */
+       count_agg_clauses_walker(clause, counts);
 }
 
 static bool
-contain_distinct_agg_clause_walker(Node *node, void *context)
+count_agg_clauses_walker(Node *node, AggClauseCounts *counts)
 {
        if (node == NULL)
                return false;
        if (IsA(node, Aggref))
        {
-               Assert(((Aggref *) node)->agglevelsup == 0);
-               if (((Aggref *) node)->aggdistinct)
-                       return true;            /* abort the tree traversal and return
-                                                                * true */
-       }
-       Assert(!IsA(node, SubLink));
-       return expression_tree_walker(node, contain_distinct_agg_clause_walker, context);
-}
+               Aggref     *aggref = (Aggref *) node;
+               Oid                     inputType;
+               HeapTuple       aggTuple;
+               Form_pg_aggregate aggform;
+               Oid                     aggtranstype;
+
+               Assert(aggref->agglevelsup == 0);
+               counts->numAggs++;
+               if (aggref->aggdistinct)
+                       counts->numDistinctAggs++;
+
+               inputType = exprType((Node *) aggref->target);
+
+               /* fetch aggregate transition datatype from pg_aggregate */
+               aggTuple = SearchSysCache(AGGFNOID,
+                                                                 ObjectIdGetDatum(aggref->aggfnoid),
+                                                                 0, 0, 0);
+               if (!HeapTupleIsValid(aggTuple))
+                       elog(ERROR, "cache lookup failed for aggregate %u",
+                                aggref->aggfnoid);
+               aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
+               aggtranstype = aggform->aggtranstype;
+               ReleaseSysCache(aggTuple);
+
+               /* resolve actual type of transition state, if polymorphic */
+               if (aggtranstype == ANYARRAYOID || aggtranstype == ANYELEMENTOID)
+               {
+                       /* have to fetch the agg's declared input type... */
+                       Oid                     agg_arg_types[FUNC_MAX_ARGS];
+                       int                     agg_nargs;
+
+                       (void) get_func_signature(aggref->aggfnoid,
+                                                                         agg_arg_types, &agg_nargs);
+                       Assert(agg_nargs == 1);
+                       aggtranstype = resolve_generic_type(aggtranstype,
+                                                                                               inputType,
+                                                                                               agg_arg_types[0]);
+               }
 
-/*
- * count_agg_clause
- *       Recursively count the Aggref nodes in an expression tree.
- *
- *       Note: this also checks for nested aggregates, which are an error.
- *
- * This does not descend into subqueries, and so should be used only after
- * reduction of sublinks to subplans, or in contexts where it's known there
- * are no subqueries.  There mustn't be outer-aggregate references either.
- */
-int
-count_agg_clause(Node *clause)
-{
-       int                     result = 0;
+               /*
+                * If the transition type is pass-by-value then it doesn't add
+                * anything to the required size of the hashtable.  If it is
+                * pass-by-reference then we have to add the estimated size of
+                * the value itself, plus palloc overhead.
+                */
+               if (!get_typbyval(aggtranstype))
+               {
+                       int32           aggtranstypmod;
+                       int32           avgwidth;
 
-       count_agg_clause_walker(clause, &result);
-       return result;
-}
+                       /*
+                        * If transition state is of same type as input, assume it's the
+                        * same typmod (same width) as well.  This works for cases like
+                        * MAX/MIN and is probably somewhat reasonable otherwise.
+                        */
+                       if (aggtranstype == inputType)
+                               aggtranstypmod = exprTypmod((Node *) aggref->target);
+                       else
+                               aggtranstypmod = -1;
 
-static bool
-count_agg_clause_walker(Node *node, int *count)
-{
-       if (node == NULL)
-               return false;
-       if (IsA(node, Aggref))
-       {
-               Assert(((Aggref *) node)->agglevelsup == 0);
-               (*count)++;
+                       avgwidth = get_typavgwidth(aggtranstype, aggtranstypmod);
+                       avgwidth = MAXALIGN(avgwidth);
+
+                       counts->transitionSpace += avgwidth + 2 * sizeof(void *);
+               }
 
                /*
                 * Complain if the aggregate's argument contains any aggregates;
                 * nested agg functions are semantically nonsensical.
                 */
-               if (contain_agg_clause((Node *) ((Aggref *) node)->target))
+               if (contain_agg_clause((Node *) aggref->target))
                        ereport(ERROR,
                                        (errcode(ERRCODE_GROUPING_ERROR),
                                  errmsg("aggregate function calls may not be nested")));
@@ -433,8 +472,8 @@ count_agg_clause_walker(Node *node, int *count)
                return false;
        }
        Assert(!IsA(node, SubLink));
-       return expression_tree_walker(node, count_agg_clause_walker,
-                                                                 (void *) count);
+       return expression_tree_walker(node, count_agg_clauses_walker,
+                                                                 (void *) counts);
 }
 
 
@@ -1157,6 +1196,7 @@ eval_const_expressions(Node *node)
        eval_const_expressions_context context;
 
        context.active_fns = NIL;       /* nothing being recursively simplified */
+       context.case_val = NULL;        /* no CASE being examined */
        context.estimate = false;       /* safe transformations only */
        return eval_const_expressions_mutator(node, &context);
 }
@@ -1181,6 +1221,7 @@ estimate_expression_value(Node *node)
        eval_const_expressions_context context;
 
        context.active_fns = NIL;       /* nothing being recursively simplified */
+       context.case_val = NULL;        /* no CASE being examined */
        context.estimate = true;        /* unsafe transformations OK */
        return eval_const_expressions_mutator(node, &context);
 }
@@ -1554,71 +1595,98 @@ eval_const_expressions_mutator(Node *node,
                 * If there are no non-FALSE alternatives, we simplify the entire
                 * CASE to the default result (ELSE result).
                 *
-                * If we have a simple-form CASE with constant test expression and
-                * one or more constant comparison expressions, we could run the
-                * implied comparisons and potentially reduce those arms to constants.
-                * This is not yet implemented, however.  At present, the
-                * CaseTestExpr placeholder will always act as a non-constant node
-                * and prevent the comparison boolean expressions from being reduced
-                * to Const nodes.
+                * If we have a simple-form CASE with constant test expression,
+                * we substitute the constant value for contained CaseTestExpr
+                * placeholder nodes, so that we have the opportunity to reduce
+                * constant test conditions.  For example this allows
+                *              CASE 0 WHEN 0 THEN 1 ELSE 1/0 END
+                * to reduce to 1 rather than drawing a divide-by-0 error.
                 *----------
                 */
                CaseExpr   *caseexpr = (CaseExpr *) node;
                CaseExpr   *newcase;
+               Node       *save_case_val;
                Node       *newarg;
                List       *newargs;
-               Node       *defresult;
-               Const      *const_input;
+               bool            const_true_cond;
+               Node       *defresult = NULL;
                ListCell   *arg;
 
                /* Simplify the test expression, if any */
                newarg = eval_const_expressions_mutator((Node *) caseexpr->arg,
                                                                                                context);
 
+               /* Set up for contained CaseTestExpr nodes */
+               save_case_val = context->case_val;
+               if (newarg && IsA(newarg, Const))
+                       context->case_val = newarg;
+               else
+                       context->case_val = NULL;
+
                /* Simplify the WHEN clauses */
                newargs = NIL;
+               const_true_cond = false;
                foreach(arg, caseexpr->args)
                {
-                       /* Simplify this alternative's condition and result */
-                       CaseWhen   *casewhen = (CaseWhen *)
-                       expression_tree_mutator((Node *) lfirst(arg),
-                                                                       eval_const_expressions_mutator,
-                                                                       (void *) context);
-
-                       Assert(IsA(casewhen, CaseWhen));
-                       if (casewhen->expr == NULL ||
-                               !IsA(casewhen->expr, Const))
-                       {
-                               newargs = lappend(newargs, casewhen);
-                               continue;
-                       }
-                       const_input = (Const *) casewhen->expr;
-                       if (const_input->constisnull ||
-                               !DatumGetBool(const_input->constvalue))
-                               continue;               /* drop alternative with FALSE condition */
+                       CaseWhen   *oldcasewhen = (CaseWhen *) lfirst(arg);
+                       Node       *casecond;
+                       Node       *caseresult;
+
+                       Assert(IsA(oldcasewhen, CaseWhen));
+
+                       /* Simplify this alternative's test condition */
+                       casecond =
+                               eval_const_expressions_mutator((Node *) oldcasewhen->expr,
+                                                                                          context);
 
                        /*
-                        * Found a TRUE condition.      If it's the first (un-dropped)
-                        * alternative, the CASE reduces to just this alternative.
+                        * If the test condition is constant FALSE (or NULL), then drop
+                        * this WHEN clause completely, without processing the result.
                         */
-                       if (newargs == NIL)
-                               return (Node *) casewhen->result;
+                       if (casecond && IsA(casecond, Const))
+                       {
+                               Const      *const_input = (Const *) casecond;
+
+                               if (const_input->constisnull ||
+                                       !DatumGetBool(const_input->constvalue))
+                                       continue;       /* drop alternative with FALSE condition */
+                               /* Else it's constant TRUE */
+                               const_true_cond = true;
+                       }
+
+                       /* Simplify this alternative's result value */
+                       caseresult =
+                               eval_const_expressions_mutator((Node *) oldcasewhen->result,
+                                                                                          context);
+
+                       /* If non-constant test condition, emit a new WHEN node */
+                       if (!const_true_cond)
+                       {
+                               CaseWhen   *newcasewhen = makeNode(CaseWhen);
 
+                               newcasewhen->expr = (Expr *) casecond;
+                               newcasewhen->result = (Expr *) caseresult;
+                               newargs = lappend(newargs, newcasewhen);
+                               continue;
+                       }
+  
                        /*
-                        * Otherwise, add it to the list, and drop all the rest.
+                        * Found a TRUE condition, so none of the remaining alternatives
+                        * can be reached.  We treat the result as the default result.
                         */
-                       newargs = lappend(newargs, casewhen);
+                       defresult = caseresult;
                        break;
                }
 
-               /* Simplify the default result */
-               defresult = eval_const_expressions_mutator((Node *) caseexpr->defresult,
-                                                                                                  context);
+               /* Simplify the default result, unless we replaced it above */
+               if (!const_true_cond)
+                       defresult =
+                               eval_const_expressions_mutator((Node *) caseexpr->defresult,
+                                                                                          context);
 
-               /*
-                * If no non-FALSE alternatives, CASE reduces to the default
-                * result
-                */
+               context->case_val = save_case_val;
+
+               /* If no non-FALSE alternatives, CASE reduces to the default result */
                if (newargs == NIL)
                        return defresult;
                /* Otherwise we need a new CASE node */
@@ -1629,6 +1697,18 @@ eval_const_expressions_mutator(Node *node,
                newcase->defresult = (Expr *) defresult;
                return (Node *) newcase;
        }
+       if (IsA(node, CaseTestExpr))
+       {
+               /*
+                * If we know a constant test value for the current CASE
+                * construct, substitute it for the placeholder.  Else just
+                * return the placeholder as-is.
+                */
+               if (context->case_val)
+                       return copyObject(context->case_val);
+               else
+                       return copyObject(node);
+       }
        if (IsA(node, ArrayExpr))
        {
                ArrayExpr  *arrayexpr = (ArrayExpr *) node;
@@ -1691,6 +1771,10 @@ eval_const_expressions_mutator(Node *node,
                        newargs = lappend(newargs, e);
                }
 
+               /* If all the arguments were constant null, the result is just null */
+               if (newargs == NIL)
+                       return (Node *) makeNullConst(coalesceexpr->coalescetype);
+
                newcoalesce = makeNode(CoalesceExpr);
                newcoalesce->coalescetype = coalesceexpr->coalescetype;
                newcoalesce->args = newargs;
@@ -1959,6 +2043,13 @@ evaluate_function(Oid funcid, Oid result_type, List *args,
        if (funcform->proretset)
                return NULL;
 
+       /*
+        * Can't simplify if it returns RECORD, since it will be needing an
+        * expected tupdesc which we can't supply here.
+        */
+       if (funcform->prorettype == RECORDOID)
+               return NULL;
+
        /*
         * Check for constant inputs and especially constant-NULL inputs.
         */
@@ -2047,7 +2138,6 @@ inline_function(Oid funcid, Oid result_type, List *args,
                                eval_const_expressions_context *context)
 {
        Form_pg_proc funcform = (Form_pg_proc) GETSTRUCT(func_tuple);
-       bool            polymorphic = false;
        Oid                     argtypes[FUNC_MAX_ARGS];
        char       *src;
        Datum           tmp;
@@ -2088,15 +2178,10 @@ inline_function(Oid funcid, Oid result_type, List *args,
                if (argtypes[i] == ANYARRAYOID ||
                        argtypes[i] == ANYELEMENTOID)
                {
-                       polymorphic = true;
                        argtypes[i] = exprType((Node *) list_nth(args, i));
                }
        }
 
-       if (funcform->prorettype == ANYARRAYOID ||
-               funcform->prorettype == ANYELEMENTOID)
-               polymorphic = true;
-
        /*
         * Setup error traceback support for ereport().  This is so that we
         * can finger the function that bad information came from.
@@ -2169,16 +2254,14 @@ inline_function(Oid funcid, Oid result_type, List *args,
        newexpr = (Node *) ((TargetEntry *) linitial(querytree->targetList))->expr;
 
        /*
-        * If the function has any arguments declared as polymorphic types,
-        * then it wasn't type-checked at definition time; must do so now.
-        * (This will raise an error if wrong, but that's okay since the
-        * function would fail at runtime anyway.  Note we do not try this
-        * until we have verified that no rewriting was needed; that's
-        * probably not important, but let's be careful.)
+        * Make sure the function (still) returns what it's declared to.  This will
+        * raise an error if wrong, but that's okay since the function would fail
+        * at runtime anyway.  Note we do not try this until we have verified that
+        * no rewriting was needed; that's probably not important, but let's be
+        * careful.
         */
-       if (polymorphic)
-               (void) check_sql_fn_retval(result_type, get_typtype(result_type),
-                                                                  querytree_list, NULL);
+       (void) check_sql_fn_retval(result_type, get_typtype(result_type),
+                                                          querytree_list, NULL);
 
        /*
         * Additional validity checks on the expression.  It mustn't return a