From f52a7dee78949190aafe716d01654ee647f3dc61 Mon Sep 17 00:00:00 2001 From: Fabrice Nicol Date: Tue, 1 Jun 2021 04:15:59 +0200 Subject: [PATCH] Add etags support for Mercury (https://mercurylang.org) Tag declarations starting lines with ':-'. By default, all declarations are tagged. Optionally, first predicate or functions in clauses can be tagged as in Prolog support using --declarations (Bug#47408). * lib-src/etags.c (test_objc_is_mercury, Mercury_functions) (mercury_skip_comment, mercury_decl, mercury_pr): Implement Mercury support. As Mercury and Objective-C have same file extension .m, a heuristic test tries to detect the language. If this test fails, --language=mercury should be used. * doc/man/etags.1: Document the change. Add Mercury-specific behavior for --declarations. This option tags first predicates or functions in clauses in addition to declarations. --- doc/man/etags.1 | 23 +- etc/NEWS | 7 + lib-src/etags.c | 490 +++++- test/manual/etags/Makefile | 3 +- test/manual/etags/merc-src/accumulator.m | 1954 ++++++++++++++++++++++ 5 files changed, 2464 insertions(+), 13 deletions(-) create mode 100644 test/manual/etags/merc-src/accumulator.m diff --git a/doc/man/etags.1 b/doc/man/etags.1 index 354f6ca88b..cbd3c1a646 100644 --- a/doc/man/etags.1 +++ b/doc/man/etags.1 @@ -1,5 +1,5 @@ .\" See section COPYING for copyright and redistribution information. -.TH ETAGS 1 "2019-06-24" "GNU Tools" "GNU" +.TH ETAGS 1 "2021-03-30" "GNU Tools" "GNU" .de BP .sp .ti -.2i @@ -50,9 +50,9 @@ format understood by .BR vi ( 1 )\c \&. Both forms of the program understand the syntax of C, Objective C, C++, Java, Fortran, Ada, Cobol, Erlang, -Forth, Go, HTML, LaTeX, Emacs Lisp/Common Lisp, Lua, Makefile, Pascal, Perl, -Ruby, Rust, PHP, PostScript, Python, Prolog, Scheme and -most assembler\-like syntaxes. +Forth, Go, HTML, LaTeX, Emacs Lisp/Common Lisp, Lua, Makefile, Mercury, Pascal, +Perl, Ruby, Rust, PHP, PostScript, Python, Prolog, Scheme and most +assembler\-like syntaxes. Both forms read the files specified on the command line, and write a tag table (defaults: \fBTAGS\fP for \fBetags\fP, \fBtags\fP for \fBctags\fP) in the current working directory. @@ -91,6 +91,9 @@ Only \fBctags\fP accepts this option. In C and derived languages, create tags for function declarations, and create tags for extern variables unless \-\-no\-globals is used. In Lisp, create tags for (defvar foo) declarations. +In Mercury, declarations start a line with "\|\fB:-\fP\|" and are always +tagged. In addition, this option tags predicates or functions in first +rules of clauses, as in Prolog. .TP .B \-D, \-\-no\-defines Do not create tag entries for C preprocessor constant definitions @@ -125,10 +128,14 @@ final brace of a function or structure definition in C and C++. Parse the following files according to the given language. More than one such options may be intermixed with filenames. Use \fB\-\-help\fP to get a list of the available languages and their default filename -extensions. The "auto" language can be used to restore automatic -detection of language based on the file name. The "none" -language may be used to disable language parsing altogether; only -regexp matching is done in this case (see the \fB\-\-regex\fP option). +extensions. For example, as Mercury and Objective-C have same +filename extension \fI.m\fP, a test based on contents tries to detect +the language. If this test fails, \fB\-\-language=\fP\fImercury\fP or +\fB\-\-language=\fP\fIobjc\fP should be used. +The "auto" language can be used to restore automatic detection of language +based on the file name. The "none" language may be used to disable language +parsing altogether; only regexp matching is done in this case (see the +\fB\-\-regex\fP option). .TP .B \-\-members Create tag entries for variables that are members of structure-like diff --git a/etc/NEWS b/etc/NEWS index 6622861aaf..c9c5c97719 100644 --- a/etc/NEWS +++ b/etc/NEWS @@ -111,6 +111,13 @@ filters. * Changes in Emacs 28.1 ++++ +** Etags support for the Mercury programming language (https://mercurylang.org). +** Etags command line option --declarations now has Mercury-specific behavior. +All Mercury declarations are tagged by default. +For compatibility with Prolog etags support, predicates and functions appearing +first in clauses will also be tagged if etags is run with '--declarations'. + +++ ** New command 'font-lock-update', bound to 'C-x x f'. This command updates the syntax highlighting in this buffer. diff --git a/lib-src/etags.c b/lib-src/etags.c index d703183cef..ac1fbb4df5 100644 --- a/lib-src/etags.c +++ b/lib-src/etags.c @@ -142,7 +142,14 @@ Copyright (C) 1984, 1987-1989, 1993-1995, 1998-2021 Free Software # define CTAGS false #endif -/* Copy to DEST from SRC (containing LEN bytes), and append a NUL byte. */ +/* Define MERCURY_HEURISTICS_RATIO as it was necessary to disambiguate + Mercury from Objective C, which have same file extensions .m + See comments before function test_objc_is_mercury for details. */ +#ifndef MERCURY_HEURISTICS_RATIO +# define MERCURY_HEURISTICS_RATIO 0.5 +#endif + +/* COPY to DEST from SRC (containing LEN bytes), and append a NUL byte. */ static void memcpyz (void *dest, void const *src, ptrdiff_t len) { @@ -359,6 +366,7 @@ #define xrnew(op, n, m) ((op) = xnrealloc (op, n, (m) * sizeof *(op))) static void Lisp_functions (FILE *); static void Lua_functions (FILE *); static void Makefile_targets (FILE *); +static void Mercury_functions (FILE *); static void Pascal_functions (FILE *); static void Perl_functions (FILE *); static void PHP_functions (FILE *); @@ -379,6 +387,7 @@ #define xrnew(op, n, m) ((op) = xnrealloc (op, n, (m) * sizeof *(op))) static bool nocase_tail (const char *); static void get_tag (char *, char **); static void get_lispy_tag (char *); +static void test_objc_is_mercury (char *, language **); static void analyze_regex (char *); static void free_regexps (void); @@ -684,10 +693,22 @@ #define STDIN 0x1001 /* returned by getopt_long on --parse-stdin */ "In makefiles, targets are tags; additionally, variables are tags\n\ unless you specify '--no-globals'."; +/* Mercury and Objective C share the same .m file extensions. */ +static const char *Mercury_suffixes [] = + {"m", + NULL}; +static const char Mercury_help [] = + "In Mercury code, tags are all declarations beginning a line with ':-'\n\ +and optionally Prolog-like definitions (first rule for a predicate or \ +function).\n\ +To enable this behavior, run etags using --declarations."; +static bool with_mercury_definitions = false; +float mercury_heuristics_ratio = MERCURY_HEURISTICS_RATIO; + static const char *Objc_suffixes [] = - { "lm", /* Objective lex file */ - "m", /* Objective C file */ - NULL }; + { "lm", /* Objective lex file */ + "m", /* By default, Objective C file will be assumed. */ + NULL}; static const char Objc_help [] = "In Objective C code, tags include Objective C definitions for classes,\n\ class categories, methods and protocols. Tags for variables and\n\ @@ -831,7 +852,9 @@ #define STDIN 0x1001 /* returned by getopt_long on --parse-stdin */ { "lisp", Lisp_help, Lisp_functions, Lisp_suffixes }, { "lua", Lua_help,Lua_functions,Lua_suffixes,NULL,Lua_interpreters}, { "makefile", Makefile_help,Makefile_targets,NULL,Makefile_filenames}, + /* objc listed before mercury as it is a better default for .m extensions. */ { "objc", Objc_help, plain_C_entries, Objc_suffixes }, + { "mercury", Mercury_help, Mercury_functions, Mercury_suffixes }, { "pascal", Pascal_help, Pascal_functions, Pascal_suffixes }, { "perl",Perl_help,Perl_functions,Perl_suffixes,NULL,Perl_interpreters}, { "php", PHP_help, PHP_functions, PHP_suffixes }, @@ -958,6 +981,9 @@ print_help (argument *argbuffer) puts ("\tand create tags for extern variables unless --no-globals is used."); + puts ("In Mercury, tag both declarations starting a line with ':-' and first\n\ + predicates or functions in clauses."); + if (CTAGS) puts ("-d, --defines\n\ Create tag entries for C #define constants and enum constants, too."); @@ -1783,6 +1809,11 @@ find_entries (FILE *inf) if (parser == NULL) { lang = get_language_from_filename (curfdp->infname, true); + + /* Disambiguate file names between Objc and Mercury */ + if (lang != NULL && strcmp (lang->name, "objc") == 0) + test_objc_is_mercury (curfdp->infname, &lang); + if (lang != NULL && lang->function != NULL) { curfdp->lang = lang; @@ -6070,6 +6101,457 @@ prolog_atom (char *s, size_t pos) return 0; } + +/* + * Support for Mercury + * + * Assumes that the declarationa starts at column 0. + * Original code by Sunichirou Sugou (1989) for Prolog. + * Rewritten by Anders Lindgren (1996) for Prolog. + * Adapted by Fabrice Nicol (2021) for Mercury. + * Note: Prolog-support behavior is preserved if + * --declarations is used, corresponding to + * with_mercury_definitions=true. + */ + +static ptrdiff_t mercury_pr (char *, char *, ptrdiff_t); +static void mercury_skip_comment (linebuffer *, FILE *); +static bool is_mercury_type = false; +static bool is_mercury_quantifier = false; +static bool is_mercury_declaration = false; + +/* + * Objective-C and Mercury have identical file extension .m + * To disambiguate between Objective C and Mercury, parse file + * with the following heuristics hook: + * - if line starts with :- choose Mercury unconditionally, + * - if line starts with #, @, choose Objective-C, + * - otherwise compute the following ratio: + * + * r = (number of lines with :- + * or % in non-commented parts or . at trimmed EOL) + * / (number of lines - number of lines starting by any amount + * of whitespace, optionally followed by comment(s)) + * + * Note: strings are neglected in counts. + * + * If r > mercury_heuristics_ratio, choose Mercury. + * Experimental tests show that a possibly optimal default value for + * this floor value is around 0.5. This is the default value for + * MERCURY_HEURISTICS_RATIO, defined in the first lines of this file. + * The closer r to 0.5, the closer the source code to pure Prolog. + * Idiomatic Mercury is scored either with r = 1.0 or higher. + * Objective-C is scored with r = 0.0. When this fails, the r-score never + * rose above 0.1 in Objective-C tests. + */ + +static void +test_objc_is_mercury (char *this_file, language **lang) +{ + if (this_file == NULL) return; + FILE* fp = fopen (this_file, "r"); + if (fp == NULL) + pfatal (this_file); + + bool blank_line = false; /* Line starting with any amount of white space + followed by optional comment(s). */ + bool commented_line = false; + bool found_dot = false; + bool only_space_before = true; + bool start_of_line = true; + int c; + intmax_t lines = 1; + intmax_t mercury_dots = 0; + intmax_t percentage_signs = 0; + intmax_t rule_signs = 0; + float ratio = 0; + + while ((c = fgetc (fp)) != EOF) + { + switch (c) + { + case '\n': + if (! blank_line) ++lines; + blank_line = true; + commented_line = false; + start_of_line = true; + if (found_dot) ++mercury_dots; + found_dot = false; + only_space_before = true; + break; + case '.': + found_dot = ! commented_line; + only_space_before = false; + break; + case '%': /* More frequent in Mercury. May be modulo in Obj.-C. */ + if (! commented_line) + { + ++percentage_signs; + /* Cannot tell if it is a comment or modulo yet for sure. + Yet works for heuristic purposes. */ + commented_line = true; + } + found_dot = false; + start_of_line = false; + only_space_before = false; + break; + case '/': + { + int d = fgetc (fp); + found_dot = false; + only_space_before = false; + if (! commented_line) + { + if (d == '*') + commented_line = true; + else + /* If d == '/', cannot tell if it is an Obj.-C comment: + may be Mercury integ. division. */ + blank_line = false; + } + } + FALLTHROUGH; + case ' ': + case '\t': + start_of_line = false; + break; + case ':': + c = fgetc (fp); + if (start_of_line) + { + if (c == '-') + { + ratio = 1.0; /* Failsafe, not an operator in Obj.-C. */ + goto out; + } + start_of_line = false; + } + else + { + /* p :- q. Frequent in Mercury. + Rare or in quoted exprs in Obj.-C. */ + if (c == '-' && ! commented_line) + ++rule_signs; + } + blank_line = false; + found_dot = false; + only_space_before = false; + break; + case '@': + case '#': + if (start_of_line || only_space_before) + { + ratio = 0.0; + goto out; + } + FALLTHROUGH; + default: + start_of_line = false; + blank_line = false; + found_dot = false; + only_space_before = false; + } + } + + /* Fallback heuristic test. Not failsafe but errless in pratice. */ + ratio = ((float) rule_signs + percentage_signs + mercury_dots) / lines; + + out: + if (fclose (fp) == EOF) + pfatal (this_file); + + if (ratio > mercury_heuristics_ratio) + { + /* Change the language from Objective C to Mercury. */ + static language lang0 = { "mercury", Mercury_help, Mercury_functions, + Mercury_suffixes }; + *lang = &lang0; + } +} + +static void +Mercury_functions (FILE *inf) +{ + char *cp, *last = NULL; + ptrdiff_t lastlen = 0, allocated = 0; + if (declarations) with_mercury_definitions = true; + + LOOP_ON_INPUT_LINES (inf, lb, cp) + { + if (cp[0] == '\0') /* Empty line. */ + continue; + else if (c_isspace (cp[0]) || cp[0] == '%') + /* A Prolog-type comment or anything other than a declaration. */ + continue; + else if (cp[0] == '/' && cp[1] == '*') /* Mercury C-type comment. */ + mercury_skip_comment (&lb, inf); + else + { + is_mercury_declaration = (cp[0] == ':' && cp[1] == '-'); + + if (is_mercury_declaration + || with_mercury_definitions) + { + ptrdiff_t len = mercury_pr (cp, last, lastlen); + if (0 < len) + { + /* Store the declaration to avoid generating duplicate + tags later. */ + if (allocated <= len) + { + xrnew (last, len + 1, 1); + allocated = len + 1; + } + memcpyz (last, cp, len); + lastlen = len; + } + } + } + } + free (last); +} + +static void +mercury_skip_comment (linebuffer *plb, FILE *inf) +{ + char *cp; + + do + { + for (cp = plb->buffer; *cp != '\0'; ++cp) + if (cp[0] == '*' && cp[1] == '/') + return; + readline (plb, inf); + } + while (perhaps_more_input (inf)); +} + +/* + * A declaration is added if it matches: + * :-( + * If with_mercury_definitions == true, we also add: + * ( + * or :- + * As for Prolog support, different arities and types are not taken into + * consideration. + * Item is added to the tags database if it doesn't match the + * name of the previous declaration. + * + * Consume a Mercury declaration. + * Return the number of bytes consumed, or 0 if there was an error. + * + * A Mercury declaration must be one of: + * :- type + * :- solver type + * :- pred + * :- func + * :- inst + * :- mode + * :- typeclass + * :- instance + * :- pragma + * :- promise + * :- initialise + * :- finalise + * :- mutable + * :- module + * :- interface + * :- implementation + * :- import_module + * :- use_module + * :- include_module + * :- end_module + * followed on the same line by an alphanumeric sequence, starting with a lower + * case letter or by a single-quoted arbitrary string. + * Single quotes can escape themselves. Backslash quotes everything. + * + * Return the size of the name of the declaration or 0 if no header was found. + * As quantifiers may precede functions or predicates, we must list them too. + */ + +static const char *Mercury_decl_tags[] = {"type", "solver type", "pred", + "func", "inst", "mode", "typeclass", "instance", "pragma", "promise", + "initialise", "finalise", "mutable", "module", "interface", "implementation", + "import_module", "use_module", "include_module", "end_module", "some", "all"}; + +static size_t +mercury_decl (char *s, size_t pos) +{ + if (s == NULL) return 0; + + size_t origpos; + origpos = pos; + + while (s + pos != NULL && (c_isalnum (s[pos]) || s[pos] == '_')) ++pos; + + unsigned char decl_type_length = pos - origpos; + char buf[decl_type_length + 1]; + memset (buf, 0, decl_type_length + 1); + + /* Mercury declaration tags. Consume them, then check the declaration item + following :- is legitimate, then go on as in the prolog case. */ + + memcpy (buf, &s[origpos], decl_type_length); + + bool found_decl_tag = false; + + if (is_mercury_quantifier) + { + if (strcmp (buf, "pred") != 0 && strcmp (buf, "func") != 0) /* Bad syntax. */ + return 0; + is_mercury_quantifier = false; /* Beset to base value. */ + found_decl_tag = true; + } + else + { + for (int j = 0; j < sizeof (Mercury_decl_tags) / sizeof (char*); ++j) + { + if (strcmp (buf, Mercury_decl_tags[j]) == 0) + { + found_decl_tag = true; + if (strcmp (buf, "type") == 0) + is_mercury_type = true; + + if (strcmp (buf, "some") == 0 + || strcmp (buf, "all") == 0) + { + is_mercury_quantifier = true; + } + + break; /* Found declaration tag of rank j. */ + } + else + /* 'solver type' has a blank in the middle, + so this is the hard case. */ + if (strcmp (buf, "solver") == 0) + { + ++pos; + while (s + pos != NULL && (c_isalnum (s[pos]) || s[pos] == '_')) + ++pos; + + decl_type_length = pos - origpos; + char buf2[decl_type_length + 1]; + memset (buf2, 0, decl_type_length + 1); + memcpy (buf2, &s[origpos], decl_type_length); + + if (strcmp (buf2, "solver type") == 0) + { + found_decl_tag = false; + break; /* Found declaration tag of rank j. */ + } + } + } + } + + /* If with_mercury_definitions == false + * this is a Mercury syntax error, ignoring... */ + + if (with_mercury_definitions) + { + if (found_decl_tag) + pos = skip_spaces (s + pos) - s; /* Skip len blanks again. */ + else + /* Prolog-like behavior + * we have parsed the predicate once, yet inappropriately + * so restarting again the parsing step. */ + pos = 0; + } + else + { + if (found_decl_tag) + pos = skip_spaces (s + pos) - s; /* Skip len blanks again. */ + else + return 0; + } + + /* From now on it is the same as for Prolog except for module dots. */ + + if (c_islower (s[pos]) || s[pos] == '_' ) + { + /* The name is unquoted. + Do not confuse module dots with end-of-declaration dots. */ + + while (c_isalnum (s[pos]) + || s[pos] == '_' + || (s[pos] == '.' /* A module dot. */ + && s + pos + 1 != NULL + && (c_isalnum (s[pos + 1]) || s[pos + 1] == '_'))) + ++pos; + + return pos - origpos; + } + else if (s[pos] == '\'') + { + ++pos; + for (;;) + { + if (s[pos] == '\'') + { + ++pos; + if (s[pos] != '\'') + break; + ++pos; /* A double quote. */ + } + else if (s[pos] == '\0') /* Multiline quoted atoms are ignored. */ + return 0; + else if (s[pos] == '\\') + { + if (s[pos+1] == '\0') + return 0; + pos += 2; + } + else + ++pos; + } + return pos - origpos; + } + else if (is_mercury_quantifier && s[pos] == '[') /* :- some [T] pred/func. */ + { + for (++pos; s + pos != NULL && s[pos] != ']'; ++pos) {} + if (s + pos == NULL) return 0; + ++pos; + pos = skip_spaces (s + pos) - s; + return mercury_decl (s, pos) + pos - origpos; + } + else + return 0; +} + +static ptrdiff_t +mercury_pr (char *s, char *last, ptrdiff_t lastlen) +{ + size_t len0 = 0; + is_mercury_type = false; + is_mercury_quantifier = false; + + if (is_mercury_declaration) + { + /* Skip len0 blanks only for declarations. */ + len0 = skip_spaces (s + 2) - s; + } + + size_t len = mercury_decl (s , len0); + if (len == 0) return 0; + len += len0; + + if (( (s[len] == '.' /* This is a statement dot, not a module dot. */ + || (s[len] == '(' && (len += 1)) + || (s[len] == ':' /* Stopping in case of a rule. */ + && s[len + 1] == '-' + && (len += 2))) + && (lastlen != len || memcmp (s, last, len) != 0) + ) + /* Types are often declared on several lines so keeping just + the first line. */ + || is_mercury_type) + { + make_tag (s, 0, true, s, len, lineno, linecharno); + return len; + } + + return 0; +} + /* * Support for Erlang diff --git a/test/manual/etags/Makefile b/test/manual/etags/Makefile index 8d56db29b7..b3a82fdba8 100644 --- a/test/manual/etags/Makefile +++ b/test/manual/etags/Makefile @@ -16,6 +16,7 @@ HTMLSRC= #JAVASRC=$(addprefix ./java-src/, ) LUASRC=$(addprefix ./lua-src/,allegro.lua test.lua) MAKESRC=$(addprefix ./make-src/,Makefile) +MERCSRC=$(addprefix ./merc-src/,accumulator.m) OBJCSRC=$(addprefix ./objc-src/,Subprocess.h Subprocess.m PackInsp.h PackInsp.m) OBJCPPSRC=$(addprefix ./objcpp-src/,SimpleCalc.H SimpleCalc.M) PASSRC=$(addprefix ./pas-src/,common.pas) @@ -32,7 +33,7 @@ YSRC= SRCS=${ADASRC} ${ASRC} ${CSRC} ${CPSRC} ${ELSRC} ${ERLSRC} ${FSRC}\ ${FORTHSRC} ${GOSRC} ${HTMLSRC} ${JAVASRC} ${LUASRC} ${MAKESRC}\ ${OBJCSRC} ${OBJCPPSRC} ${PASSRC} ${PHPSRC} ${PERLSRC} ${PSSRC}\ - ${PROLSRC} ${PYTSRC} ${RBSRC} ${RSSRC} ${SCMSRC} ${TEXSRC} ${YSRC} + ${PROLSRC} ${PYTSRC} ${RBSRC} ${RSSRC} ${SCMSRC} ${TEXSRC} ${YSRC} ${MERCSRC} NONSRCS=./f-src/entry.strange ./erl-src/lists.erl ./cp-src/clheir.hpp.gz ETAGS_PROG=../../../lib-src/etags diff --git a/test/manual/etags/merc-src/accumulator.m b/test/manual/etags/merc-src/accumulator.m new file mode 100644 index 0000000000..94a6b1d858 --- /dev/null +++ b/test/manual/etags/merc-src/accumulator.m @@ -0,0 +1,1954 @@ +%---------------------------------------------------------------------------% +% vim: ft=mercury ts=4 sw=4 et +%---------------------------------------------------------------------------% +% Copyright (C) 1999-2000,2002-2007, 2009-2012 The University of Melbourne. +% Copyright (C) 2015 The Mercury team. +% This file may only be copied under the terms of the GNU General +% Public License - see the file COPYING in the Mercury distribution. +%---------------------------------------------------------------------------% +% +% Module: accumulator.m. +% Main authors: petdr. +% +% Attempts to transform a single proc to a tail recursive form by +% introducing accumulators. The algorithm can do this if the code after +% the recursive call has either the order independent state update or +% associative property. +% +% /* Order independent State update property */ +% :- promise all [A,B,S0,S] +% ( +% (some[SA] (update(A, S0, SA), update(B, SA, S))) +% <=> +% (some[SB] (update(B, S0, SB), update(A, SB, S))) +% ). +% +% /* Associativity property */ +% :- promise all [A,B,C,ABC] +% ( +% (some[AB] (assoc(A, B, AB), assoc(AB, C, ABC))) +% <=> +% (some[BC] (assoc(B, C, BC), assoc(A, BC, ABC))) +% ). +% +% XXX What about exceptions and non-termination? +% +% The promise declarations above only provide promises about the declarative +% semantics, but in order to apply this optimization, we ought to check that +% it will preserve the operational semantics (modulo whatever changes are +% allowed by the language semantics options). +% +% Currently we check and respect the --fully-strict option, but not the +% --no-reorder-conj option. XXX we should check --no-reorder-conj! +% If --no-reorder-conj was set, it would still be OK to apply this +% transformation, but ONLY in cases where the goals which get reordered +% are guaranteed not to throw any exceptions. +% +% The algorithm implemented is a combination of the algorithms from +% "Making Mercury Programs Tail Recursive" and +% "State Update Transformation", which can be found at +% . +% +% Note that currently "State Update Transformation" paper only resides +% in CVS papers archive in the directory update, but has been submitted +% to PPDP '00. +% +% The transformation recognises predicates in the form +% +% p(In, OutUpdate, OutAssoc) :- +% minimal(In), +% initialize(OutUpdate), +% base(OutAssoc). +% p(In, OutUpdate, OutAssoc) :- +% decompose(In, Current, Rest), +% p(Rest, OutUpdate0, OutAssoc0), +% update(Current, OutUpdate0, OutUpdate), +% assoc(Current, OutAssoc0, OutAssoc). +% +% which can be transformed by the algorithm in "State Update Transformation" to +% +% p(In, OutUpdate, OutAssoc) :- +% initialize(AccUpdate), +% p_acc(In, OutUpdate, OutAssoc, AccUpdate). +% +% p_acc(In, OutUpdate, OutAssoc, AccUpdate) :- +% minimal(In), +% base(OutAssoc), +% OutUpdate = AccUpdate. +% p_acc(In, OutUpdate, OutAssoc, AccUpdate0) :- +% decompose(In, Current, Rest), +% update(Current, AccUpdate0, AccUpdate), +% p_acc(Rest, OutUpdate, OutAssoc0, AccUpdate), +% assoc(Current, OutAssoc0, OutAssoc). +% +% we then apply the algorithm from "Making Mercury Programs Tail Recursive" +% to p_acc to obtain +% +% p_acc(In, OutUpdate, OutAssoc, AccUpdate) :- +% minimal(In), +% base(OutAssoc), +% OutUpdate = AccUpdate. +% p_acc(In, OutUpdate, OutAssoc, AccUpdate0) :- +% decompose(In, Current, Rest), +% update(Current, AccUpdate0, AccUpdate), +% p_acc2(Rest, OutUpdate, OutAssoc, AccUpdate, Current). +% +% p_acc2(In, OutUpdate, OutAssoc, AccUpdate0, AccAssoc0) :- +% minimal(In), +% base(Base), +% assoc(AccAssoc0, Base, OutAssoc), +% OutUpdate = AccUpdate0. +% p_acc2(In, OutUpdate, OutAssoc, AccUpdate0, AccAssoc0) :- +% decompose(In, Current, Rest), +% update(Current, AccUpdate0, AccUpdate), +% assoc(AccAssoc0, Current, AccAssoc), +% p_acc2(Rest, OutUpdate, OutAssoc, AccUpdate, AccAssoc). +% +% p_acc is no longer recursive and is only ever called from p, so we +% inline p_acc into p to obtain the final schema. +% +% p(In, OutUpdate, OutAssoc) :- +% minimal(In), +% base(OutAssoc), +% initialize(AccUpdate), +% OutUpdate = AccUpdate. +% p(In, OutUpdate, OutAssoc) :- +% decompose(In, Current, Rest), +% initialize(AccUpdate0), +% update(Current, AccUpdate0, AccUpdate), +% p_acc2(Rest, OutUpdate, OutAssoc, AccUpdate, Current). +% +% p_acc2(In, OutUpdate, OutAssoc, AccUpdate0, AccAssoc0) :- +% minimal(In), +% base(Base), +% assoc(AccAssoc0, Base, OutAssoc), +% OutUpdate = AccUpdate0. +% p_acc2(In, OutUpdate, OutAssoc, AccUpdate0, AccAssoc0) :- +% decompose(In, Current, Rest), +% update(Current, AccUpdate0, AccUpdate), +% assoc(AccAssoc0, Current, AccAssoc), +% p_acc2(Rest, OutUpdate, OutAssoc, AccUpdate, AccAssoc). +% +% The only real difficulty in this new transformation is identifying the +% initialize/1 and base/1 goals from the original base case. +% +% Note that if the recursive clause contains multiple calls to p, the +% transformation attempts to move each recursive call to the end +% until one succeeds. This makes the order of independent recursive +% calls in the body irrelevant. +% +% XXX Replace calls to can_reorder_goals with calls to the version that +% use the intermodule-analysis framework. +% +%---------------------------------------------------------------------------% + +:- module transform_hlds.accumulator. +:- interface. + +:- import_module hlds. +:- import_module hlds.hlds_module. +:- import_module hlds.hlds_pred. + +:- import_module univ. + + % Attempt to transform a procedure into accumulator recursive form. + % If we succeed, we will add the recursive version of the procedure + % to the module_info. However, we may also encounter errors, which + % we will add to the list of error_specs in the univ accumulator. + % +:- pred accu_transform_proc(pred_proc_id::in, pred_info::in, + proc_info::in, proc_info::out, module_info::in, module_info::out, + univ::in, univ::out) is det. + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + +:- implementation. + +:- import_module hlds.assertion. +:- import_module hlds.goal_util. +:- import_module hlds.hlds_error_util. +:- import_module hlds.hlds_goal. +:- import_module hlds.hlds_out. +:- import_module hlds.hlds_out.hlds_out_util. +:- import_module hlds.hlds_promise. +:- import_module hlds.instmap. +:- import_module hlds.pred_table. +:- import_module hlds.quantification. +:- import_module hlds.status. +:- import_module hlds.vartypes. +:- import_module libs. +:- import_module libs.globals. +:- import_module libs.optimization_options. +:- import_module libs.options. +:- import_module mdbcomp. +:- import_module mdbcomp.sym_name. +:- import_module parse_tree. +:- import_module parse_tree.error_util. +:- import_module parse_tree.prog_data. +:- import_module parse_tree.prog_mode. +:- import_module parse_tree.prog_util. +:- import_module parse_tree.set_of_var. +:- import_module transform_hlds.goal_store. + +:- import_module assoc_list. +:- import_module bool. +:- import_module int. +:- import_module io. +:- import_module list. +:- import_module map. +:- import_module maybe. +:- import_module pair. +:- import_module require. +:- import_module set. +:- import_module solutions. +:- import_module string. +:- import_module term. +:- import_module varset. + +%---------------------------------------------------------------------------% + + % The form of the goal around the base and recursive cases. + % +:- type top_level + ---> switch_base_rec + ; switch_rec_base + ; disj_base_rec + ; disj_rec_base + ; ite_base_rec + ; ite_rec_base. + + % An accu_goal_id represents a goal. The first field says which conjunction + % the goal came from (the base case or the recursive case), and the second + % gives the location of the goal in that conjunction. + % +:- type accu_goal_id + ---> accu_goal_id(accu_case, int). + +:- type accu_case + ---> accu_base + ; accu_rec. + + % The goal_store associates a goal with each goal_id. + % +:- type accu_goal_store == goal_store(accu_goal_id). + + % A substitution from the first variable name to the second. + % +:- type accu_subst == map(prog_var, prog_var). + +:- type accu_warning + ---> accu_warn(prog_context, pred_id, prog_var, prog_var). + % Warn that two prog_vars in a call to pred_id at the given context + % were swapped, which may cause an efficiency problem. + +%---------------------------------------------------------------------------% + +accu_transform_proc(proc(PredId, ProcId), PredInfo, !ProcInfo, !ModuleInfo, + !Cookie) :- + module_info_get_globals(!.ModuleInfo, Globals), + globals.get_opt_tuple(Globals, OptTuple), + DoLCMC = OptTuple ^ ot_opt_lcmc_accumulator, + globals.lookup_bool_option(Globals, fully_strict, FullyStrict), + ( if + should_attempt_accu_transform(!ModuleInfo, PredId, ProcId, PredInfo, + !ProcInfo, FullyStrict, DoLCMC, Warnings) + then + globals.lookup_bool_option(Globals, very_verbose, VeryVerbose), + ( + VeryVerbose = yes, + trace [io(!IO)] ( + module_info_get_name(!.ModuleInfo, ModuleName), + get_progress_output_stream(Globals, ModuleName, + ProgressStream, !IO), + PredStr = pred_id_to_string(!.ModuleInfo, PredId), + io.format(ProgressStream, + "%% Accumulators introduced into %s\n", [s(PredStr)], !IO) + ) + ; + VeryVerbose = no + ), + + ( + Warnings = [] + ; + Warnings = [_ | _], + pred_info_get_context(PredInfo, Context), + PredPieces = describe_one_pred_name(!.ModuleInfo, + should_module_qualify, PredId), + InPieces = [words("In") | PredPieces] ++ [suffix(":"), nl], + InMsg = simple_msg(Context, + [option_is_set(warn_accumulator_swaps, yes, + [always(InPieces)])]), + + proc_info_get_varset(!.ProcInfo, VarSet), + generate_warnings(!.ModuleInfo, VarSet, Warnings, WarnMsgs), + ( + Warnings = [_], + EnsurePieces = [words("Please ensure that this"), + words("argument rearrangement does not introduce"), + words("performance problems.")] + ; + Warnings = [_, _ | _], + EnsurePieces = [words("Please ensure that these"), + words("argument rearrangements do not introduce"), + words("performance problems.")] + ), + SuppressPieces = + [words("These warnings can be suppressed by"), + quote("--no-warn-accumulator-swaps"), suffix(".")], + VerbosePieces = [words("If a predicate has been declared"), + words("associative"), + words("via a"), quote("promise"), words("declaration,"), + words("the compiler will rearrange the order of"), + words("the arguments in calls to that predicate,"), + words("if by so doing it makes the containing predicate"), + words("tail recursive. In such situations, the compiler"), + words("will issue this warning. If this reordering"), + words("changes the performance characteristics"), + words("of the call to the predicate, use"), + quote("--no-accumulator-introduction"), + words("to turn the optimization off, or "), + quote("--no-warn-accumulator-swaps"), + words("to turn off the warnings.")], + EnsureSuppressMsg = simple_msg(Context, + [option_is_set(warn_accumulator_swaps, yes, + [always(EnsurePieces), always(SuppressPieces)]), + verbose_only(verbose_once, VerbosePieces)]), + Severity = severity_conditional(warn_accumulator_swaps, yes, + severity_warning, no), + Msgs = [InMsg | WarnMsgs] ++ [EnsureSuppressMsg], + Spec = error_spec($pred, Severity, phase_accumulator_intro, Msgs), + + det_univ_to_type(!.Cookie, Specs0), + Specs = [Spec | Specs0], + type_to_univ(Specs, !:Cookie) + ) + else + true + ). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + +:- pred generate_warnings(module_info::in, prog_varset::in, + list(accu_warning)::in, list(error_msg)::out) is det. + +generate_warnings(_, _, [], []). +generate_warnings(ModuleInfo, VarSet, [Warning | Warnings], [Msg | Msgs]) :- + generate_warning(ModuleInfo, VarSet, Warning, Msg), + generate_warnings(ModuleInfo, VarSet, Warnings, Msgs). + +:- pred generate_warning(module_info::in, prog_varset::in, accu_warning::in, + error_msg::out) is det. + +generate_warning(ModuleInfo, VarSet, Warning, Msg) :- + Warning = accu_warn(Context, PredId, VarA, VarB), + PredPieces = describe_one_pred_name(ModuleInfo, should_module_qualify, + PredId), + + varset.lookup_name(VarSet, VarA, VarAName), + varset.lookup_name(VarSet, VarB, VarBName), + + Pieces = [words("warning: the call to")] ++ PredPieces ++ + [words("has had the location of the variables"), + quote(VarAName), words("and"), quote(VarBName), + words("swapped to allow accumulator introduction."), nl], + Msg = simplest_msg(Context, Pieces). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + + % should_attempt_accu_transform is only true iff the current proc + % has been transformed to call the newly created accumulator proc. + % +:- pred should_attempt_accu_transform(module_info::in, module_info::out, + pred_id::in, proc_id::in, pred_info::in, proc_info::in, proc_info::out, + bool::in, maybe_opt_lcmc_accumulator::in, + list(accu_warning)::out) is semidet. + +should_attempt_accu_transform(!ModuleInfo, PredId, ProcId, PredInfo, + !ProcInfo, FullyStrict, DoLCMC, Warnings) :- + proc_info_get_goal(!.ProcInfo, Goal0), + proc_info_get_headvars(!.ProcInfo, HeadVars), + proc_info_get_initial_instmap(!.ModuleInfo, !.ProcInfo, InitialInstMap), + accu_standardize(Goal0, Goal), + identify_goal_type(PredId, ProcId, Goal, InitialInstMap, + TopLevel, Base, BaseInstMap, Rec, RecInstMap), + + C = initialize_goal_store(Rec, RecInstMap, Base, BaseInstMap), + identify_recursive_calls(PredId, ProcId, C, RecCallIds), + list.length(Rec, M), + + should_attempt_accu_transform_2(!ModuleInfo, PredId, PredInfo, !ProcInfo, + HeadVars, InitialInstMap, TopLevel, FullyStrict, DoLCMC, + RecCallIds, C, M, Rec, Warnings). + + % should_attempt_accu_transform_2 takes a list of locations of the + % recursive calls, and attempts to introduce accumulator into each of the + % recursive calls, stopping at the first one that succeeds. + % This catches the following case, as selecting the first recursive call + % allows the second recursive call to be moved before it, and + % OutA is in the correct spot in list.append. + % + % p(InA, OutA), + % p(InB, OutB), + % list.append(OutB, OutA, Out) + % +:- pred should_attempt_accu_transform_2(module_info::in, module_info::out, + pred_id::in, pred_info::in, proc_info::in, proc_info::out, + list(prog_var)::in, instmap::in, top_level::in, bool::in, + maybe_opt_lcmc_accumulator::in, + list(accu_goal_id)::in, accu_goal_store::in, int::in, list(hlds_goal)::in, + list(accu_warning)::out) is semidet. + +should_attempt_accu_transform_2(!ModuleInfo, PredId, PredInfo, !ProcInfo, + HeadVars, InitialInstMap, TopLevel, FullyStrict, DoLCMC, + [Id | Ids], C, M, Rec, Warnings) :- + proc_info_get_vartypes(!.ProcInfo, VarTypes0), + identify_out_and_out_prime(!.ModuleInfo, VarTypes0, InitialInstMap, + Id, Rec, HeadVars, Out, OutPrime, HeadToCallSubst, CallToHeadSubst), + ( if + accu_stage1(!.ModuleInfo, VarTypes0, FullyStrict, DoLCMC, Id, M, C, + Sets), + accu_stage2(!.ModuleInfo, !.ProcInfo, Id, C, Sets, OutPrime, Out, + VarSet, VarTypes, Accs, BaseCase, BasePairs, Substs, CS, + WarningsPrime), + accu_stage3(Id, Accs, VarSet, VarTypes, C, CS, Substs, + HeadToCallSubst, CallToHeadSubst, BaseCase, BasePairs, Sets, Out, + TopLevel, PredId, PredInfo, !ProcInfo, !ModuleInfo) + then + Warnings = WarningsPrime + else + should_attempt_accu_transform_2(!ModuleInfo, PredId, PredInfo, + !ProcInfo, HeadVars, InitialInstMap, TopLevel, FullyStrict, DoLCMC, + Ids, C, M, Rec, Warnings) + ). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + + % Transform the goal into a standard form that is amenable to + % introducing accumulators. + % + % At the moment all this does is remove any extra disj/conj wrappers + % around the top level goal. + % + % Future work is for this code to rearrange code with multiple base + % and recursive cases into a single base and recursive case. + % +:- pred accu_standardize(hlds_goal::in, hlds_goal::out) is det. + +accu_standardize(Goal0, Goal) :- + ( if + Goal0 = hlds_goal(GoalExpr0, _), + ( + GoalExpr0 = conj(plain_conj, [Goal1]) + ; + GoalExpr0 = disj([Goal1]) + ) + then + accu_standardize(Goal1, Goal) + else + Goal = Goal0 + ). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + + % This predicate takes the original goal and identifies the `shape' + % of the goal around the recursive and base cases. + % + % Note that the base case can contain a recursive call, as the + % transformation doesn't depend on what is in the base case. + % +:- pred identify_goal_type(pred_id::in, proc_id::in, hlds_goal::in, + instmap::in, top_level::out, list(hlds_goal)::out, instmap::out, + list(hlds_goal)::out, instmap::out) is semidet. + +identify_goal_type(PredId, ProcId, Goal, InitialInstMap, Type, + Base, BaseInstMap, Rec, RecInstMap) :- + Goal = hlds_goal(GoalExpr, _GoalInfo), + ( + GoalExpr = switch(_Var, _CanFail, Cases), + ( if + Cases = [case(_IdA, [], GoalA), case(_IdB, [], GoalB)], + goal_to_conj_list(GoalA, GoalAList), + goal_to_conj_list(GoalB, GoalBList) + then + ( if is_recursive_case(GoalAList, proc(PredId, ProcId)) then + Type = switch_rec_base, + Base = GoalBList, + Rec = GoalAList + else if is_recursive_case(GoalBList, proc(PredId, ProcId)) then + Type = switch_base_rec, + Base = GoalAList, + Rec = GoalBList + else + fail + ), + BaseInstMap = InitialInstMap, + RecInstMap = InitialInstMap + else + fail + ) + ; + GoalExpr = disj(Goals), + ( if + Goals = [GoalA, GoalB], + goal_to_conj_list(GoalA, GoalAList), + goal_to_conj_list(GoalB, GoalBList) + then + ( if is_recursive_case(GoalAList, proc(PredId, ProcId)) then + Type = disj_rec_base, + Base = GoalBList, + Rec = GoalAList + else if is_recursive_case(GoalBList, proc(PredId, ProcId)) then + Type = disj_base_rec, + Base = GoalAList, + Rec = GoalBList + else + fail + ), + BaseInstMap = InitialInstMap, + RecInstMap = InitialInstMap + else + fail + ) + ; + GoalExpr = if_then_else(_Vars, Cond, Then, Else), + Cond = hlds_goal(_CondGoalExpr, CondGoalInfo), + CondInstMapDelta = goal_info_get_instmap_delta(CondGoalInfo), + + goal_to_conj_list(Then, GoalAList), + goal_to_conj_list(Else, GoalBList), + ( if is_recursive_case(GoalAList, proc(PredId, ProcId)) then + Type = ite_rec_base, + Base = GoalBList, + Rec = GoalAList, + + BaseInstMap = InitialInstMap, + apply_instmap_delta(CondInstMapDelta, InitialInstMap, RecInstMap) + else if is_recursive_case(GoalBList, proc(PredId, ProcId)) then + Type = ite_base_rec, + Base = GoalAList, + Rec = GoalBList, + + RecInstMap = InitialInstMap, + apply_instmap_delta(CondInstMapDelta, InitialInstMap, BaseInstMap) + else + fail + ) + ). + + % is_recursive_case(Gs, Id) is true iff the list of goals, Gs, + % contains a call to the procedure specified by Id, where the call + % is located in a position that can be used by the transformation + % (i.e. not hidden in a compound goal). + % +:- pred is_recursive_case(list(hlds_goal)::in, pred_proc_id::in) is semidet. + +is_recursive_case(Goals, proc(PredId, ProcId)) :- + list.append(_Initial, [RecursiveCall | _Final], Goals), + RecursiveCall = hlds_goal(plain_call(PredId, ProcId, _, _, _, _), _). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + + % The store info is folded over the list of goals which + % represent the base and recursive case conjunctions. +:- type store_info + ---> store_info( + store_loc :: int, + % The location of the goal in the conjunction. + store_instmap :: instmap, + store_goals :: accu_goal_store + ). + + % Initialise the goal_store, which will hold the C_{a,b} goals. + % +:- func initialize_goal_store(list(hlds_goal), instmap, + list(hlds_goal), instmap) = accu_goal_store. + +initialize_goal_store(Rec, RecInstMap, Base, BaseInstMap) = C :- + goal_store_init(C0), + list.foldl3(accu_store(accu_rec), Rec, + 1, _, RecInstMap, _, C0, C1), + list.foldl3(accu_store(accu_base), Base, + 1, _, BaseInstMap, _, C1, C). + +:- pred accu_store(accu_case::in, hlds_goal::in, + int::in, int::out, instmap::in, instmap::out, + accu_goal_store::in, accu_goal_store::out) is det. + +accu_store(Case, Goal, !N, !InstMap, !GoalStore) :- + Id = accu_goal_id(Case, !.N), + goal_store_det_insert(Id, stored_goal(Goal, !.InstMap), !GoalStore), + + !:N = !.N + 1, + Goal = hlds_goal(_, GoalInfo), + InstMapDelta = goal_info_get_instmap_delta(GoalInfo), + apply_instmap_delta(InstMapDelta, !InstMap). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + + % Determine the k's which are recursive calls. + % Note that this doesn't find recursive calls which are `hidden' + % in compound goals, this is not a problem as currently we can't use + % these to do transformation. + % +:- pred identify_recursive_calls(pred_id::in, proc_id::in, + accu_goal_store::in, list(accu_goal_id)::out) is det. + +identify_recursive_calls(PredId, ProcId, GoalStore, Ids) :- + P = + ( pred(Key::out) is nondet :- + goal_store_member(GoalStore, Key, stored_goal(Goal, _InstMap)), + Key = accu_goal_id(accu_rec, _), + Goal = hlds_goal(plain_call(PredId, ProcId, _, _, _, _), _) + ), + solutions.solutions(P, Ids). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + + % Determine the variables which are members of the sets Out and Out', + % and initialize the substitutions between the two sets. + % + % This is done by identifing those variables whose instantiatedness change + % in the goals after the recursive call and are headvars. + % + % Note that we are only identifying the output variables which will need + % to be accumulated, as there may be other output variables which are + % produced prior to the recursive call. + % +:- pred identify_out_and_out_prime(module_info::in, vartypes::in, instmap::in, + accu_goal_id::in, list(hlds_goal)::in, + list(prog_var)::in, list(prog_var)::out, list(prog_var)::out, + accu_subst::out, accu_subst::out) is det. + +identify_out_and_out_prime(ModuleInfo, VarTypes, InitialInstMap, GoalId, + Rec, HeadVars, Out, OutPrime, HeadToCallSubst, CallToHeadSubst) :- + GoalId = accu_goal_id(_Case, K), + ( if + list.take(K, Rec, InitialGoals), + list.drop(K-1, Rec, FinalGoals), + FinalGoals = [hlds_goal(plain_call(_, _, Args, _, _, _), _) | Rest] + then + goal_list_instmap_delta(InitialGoals, InitInstMapDelta), + apply_instmap_delta( InitInstMapDelta, + InitialInstMap, InstMapBeforeRest), + + goal_list_instmap_delta(Rest, InstMapDelta), + apply_instmap_delta(InstMapDelta, InstMapBeforeRest, InstMapAfterRest), + + instmap_changed_vars(ModuleInfo, VarTypes, + InstMapBeforeRest, InstMapAfterRest, ChangedVars), + + assoc_list.from_corresponding_lists(HeadVars, Args, HeadArg0), + + Member = + ( pred(M::in) is semidet :- + M = HeadVar - _, + set_of_var.member(ChangedVars, HeadVar) + ), + list.filter(Member, HeadArg0, HeadArg), + list.map(fst, HeadArg, Out), + list.map(snd, HeadArg, OutPrime), + + map.from_assoc_list(HeadArg, HeadToCallSubst), + + list.map((pred(X-Y::in, Y-X::out) is det), HeadArg, ArgHead), + map.from_assoc_list(ArgHead, CallToHeadSubst) + else + unexpected($pred, "test failed") + ). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + + % For each goal after the recursive call, we place that goal + % into a set according to what properties that goal has. + % For the definition of what goes into each set, inspect the documentation + % for the functions named before, assoc, and so on. + % +:- type accu_sets + ---> accu_sets( + as_before :: set(accu_goal_id), + as_assoc :: set(accu_goal_id), + as_construct_assoc :: set(accu_goal_id), + as_construct :: set(accu_goal_id), + as_update :: set(accu_goal_id), + as_reject :: set(accu_goal_id) + ). + + % Stage 1 is responsible for identifying which goals are associative, + % which can be moved before the recursive call and so on. + % +:- pred accu_stage1(module_info::in, vartypes::in, bool::in, + maybe_opt_lcmc_accumulator::in, accu_goal_id::in, int::in, + accu_goal_store::in, accu_sets::out) is semidet. + +accu_stage1(ModuleInfo, VarTypes, FullyStrict, DoLCMC, GoalId, M, GoalStore, + Sets) :- + GoalId = accu_goal_id(Case, K), + NextGoalId = accu_goal_id(Case, K + 1), + accu_sets_init(Sets0), + accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M, + GoalStore, Sets0, Sets1), + Sets1 = accu_sets(Before, Assoc, + ConstructAssoc, Construct, Update, Reject), + Sets = accu_sets(Before `set.union` set_upto(Case, K - 1), Assoc, + ConstructAssoc, Construct, Update, Reject), + + % Continue the transformation only if the set reject is empty and + % the set assoc or update contains something that needs to be moved + % before the recursive call. + set.is_empty(Reject), + ( + not set.is_empty(Assoc) + ; + not set.is_empty(Update) + ), + ( + DoLCMC = do_not_opt_lcmc_accumulator, + % If LCMC is not turned on, then there must be no construction + % unifications after the recursive call. + set.is_empty(Construct), + set.is_empty(ConstructAssoc) + ; + DoLCMC = opt_lcmc_accumulator + ). + + % For each goal after the recursive call decide which set + % the goal belongs to. + % +:- pred accu_stage1_2(module_info::in, vartypes::in, bool::in, + accu_goal_id::in, int::in, int::in, accu_goal_store::in, + accu_sets::in, accu_sets::out) is det. + +accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, GoalId, K, M, GoalStore, + !Sets) :- + GoalId = accu_goal_id(Case, I), + NextGoalId = accu_goal_id(Case, I + 1), + ( if I > M then + true + else + ( if + accu_before(ModuleInfo, VarTypes, FullyStrict, GoalId, K, + GoalStore, !.Sets) + then + !Sets ^ as_before := set.insert(!.Sets ^ as_before, GoalId), + accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M, + GoalStore, !Sets) + else if + accu_assoc(ModuleInfo, VarTypes, FullyStrict, GoalId, K, + GoalStore, !.Sets) + then + !Sets ^ as_assoc := set.insert(!.Sets ^ as_assoc, GoalId), + accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M, + GoalStore, !Sets) + else if + accu_construct(ModuleInfo, VarTypes, FullyStrict, GoalId, K, + GoalStore, !.Sets) + then + !Sets ^ as_construct := set.insert(!.Sets ^ as_construct, GoalId), + accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M, + GoalStore, !Sets) + else if + accu_construct_assoc(ModuleInfo, VarTypes, FullyStrict, GoalId, K, + GoalStore, !.Sets) + then + !Sets ^ as_construct_assoc := + set.insert(!.Sets ^ as_construct_assoc, GoalId), + accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M, + GoalStore, !Sets) + else if + accu_update(ModuleInfo, VarTypes, FullyStrict, GoalId, K, + GoalStore, !.Sets) + then + !Sets ^ as_update := set.insert(!.Sets ^ as_update, GoalId), + accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M, + GoalStore, !Sets) + else + !Sets ^ as_reject := set.insert(!.Sets ^ as_reject, GoalId) + ) + ). + +%---------------------------------------------------------------------------% + +:- pred accu_sets_init(accu_sets::out) is det. + +accu_sets_init(Sets) :- + set.init(EmptySet), + Before = EmptySet, + Assoc = EmptySet, + ConstructAssoc = EmptySet, + Construct = EmptySet, + Update = EmptySet, + Reject = EmptySet, + Sets = accu_sets(Before, Assoc, ConstructAssoc, Construct, Update, Reject). + + % set_upto(Case, K) returns the set + % {accu_goal_id(Case, 1) .. accu_goal_id(Case, K)}. + % +:- func set_upto(accu_case, int) = set(accu_goal_id). + +set_upto(Case, K) = Set :- + ( if K =< 0 then + set.init(Set) + else + Set0 = set_upto(Case, K - 1), + set.insert(accu_goal_id(Case, K), Set0, Set) + ). + +%---------------------------------------------------------------------------% + + % A goal is a member of the before set iff the goal only depends on goals + % which are before the recursive call or can be moved before the recursive + % call (member of the before set). + % +:- pred accu_before(module_info::in, vartypes::in, bool::in, + accu_goal_id::in, int::in, accu_goal_store::in, accu_sets::in) is semidet. + +accu_before(ModuleInfo, VarTypes, FullyStrict, GoalId, K, GoalStore, Sets) :- + GoalId = accu_goal_id(Case, _I), + Before = Sets ^ as_before, + goal_store_lookup(GoalStore, GoalId, stored_goal(LaterGoal, LaterInstMap)), + ( + member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId, + stored_goal(EarlierGoal, EarlierInstMap)), + not can_reorder_goals_old(ModuleInfo, VarTypes, FullyStrict, + EarlierInstMap, EarlierGoal, LaterInstMap, LaterGoal) + ) + => + ( + set.member(LessThanGoalId, set_upto(Case, K - 1) `union` Before) + ). + + % A goal is a member of the assoc set iff the goal only depends on goals + % upto and including the recursive call and goals which can be moved + % before the recursive call (member of the before set) AND the goal + % is associative. + % +:- pred accu_assoc(module_info::in, vartypes::in, bool::in, + accu_goal_id::in, int::in, accu_goal_store::in, accu_sets::in) is semidet. + +accu_assoc(ModuleInfo, VarTypes, FullyStrict, GoalId, K, GoalStore, Sets) :- + GoalId = accu_goal_id(Case, _I), + Before = Sets ^ as_before, + goal_store_lookup(GoalStore, GoalId, stored_goal(LaterGoal, LaterInstMap)), + LaterGoal = hlds_goal(plain_call(PredId, _, Args, _, _, _), _), + accu_is_associative(ModuleInfo, PredId, Args, _), + ( + % XXX LessThanGoalId was _N - J, not N - J: it ignored the case. + % See the diff with the previous version. + member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId, + stored_goal(EarlierGoal, EarlierInstMap)), + not can_reorder_goals_old(ModuleInfo, VarTypes, FullyStrict, + EarlierInstMap, EarlierGoal, LaterInstMap, LaterGoal) + ) + => + ( + set.member(LessThanGoalId, set_upto(Case, K) `union` Before) + ). + + % A goal is a member of the construct set iff the goal only depends + % on goals upto and including the recursive call and goals which + % can be moved before the recursive call (member of the before set) + % AND the goal is construction unification. + % +:- pred accu_construct(module_info::in, vartypes::in, bool::in, + accu_goal_id::in, int::in, accu_goal_store::in, accu_sets::in) is semidet. + +accu_construct(ModuleInfo, VarTypes, FullyStrict, GoalId, K, GoalStore, + Sets) :- + GoalId = accu_goal_id(Case, _I), + Before = Sets ^ as_before, + Construct = Sets ^ as_construct, + goal_store_lookup(GoalStore, GoalId, stored_goal(LaterGoal, LaterInstMap)), + LaterGoal = hlds_goal(unify(_, _, _, Unify, _), _GoalInfo), + Unify = construct(_, _, _, _, _, _, _), + ( + % XXX LessThanGoalId was _N - J, not N - J: it ignored the case. + % See the diff with the previous version. + member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId, + stored_goal(EarlierGoal, EarlierInstMap)), + not can_reorder_goals_old(ModuleInfo, VarTypes, FullyStrict, + EarlierInstMap, EarlierGoal, LaterInstMap, LaterGoal) + ) + => + ( + set.member(LessThanGoalId, + set_upto(Case, K) `union` Before `union` Construct) + ). + + % A goal is a member of the construct_assoc set iff the goal depends only + % on goals upto and including the recursive call and goals which can be + % moved before the recursive call (member of the before set) and goals + % which are associative AND the goal is construction unification AND + % there is only one member of the assoc set which the construction + % unification depends on AND the construction unification can be expressed + % as a call to the member of the assoc set which the construction + % unification depends on. + % +:- pred accu_construct_assoc(module_info::in, vartypes::in, bool::in, + accu_goal_id::in, int::in, accu_goal_store::in, accu_sets::in) is semidet. + +accu_construct_assoc(ModuleInfo, VarTypes, FullyStrict, + GoalId, K, GoalStore, Sets) :- + GoalId = accu_goal_id(Case, _I), + Before = Sets ^ as_before, + Assoc = Sets ^ as_assoc, + ConstructAssoc = Sets ^ as_construct_assoc, + goal_store_lookup(GoalStore, GoalId, stored_goal(LaterGoal, LaterInstMap)), + LaterGoal = hlds_goal(unify(_, _, _, Unify, _), _GoalInfo), + Unify = construct(_, ConsId, _, _, _, _, _), + + goal_store_all_ancestors(GoalStore, GoalId, VarTypes, ModuleInfo, + FullyStrict, Ancestors), + + set.is_singleton(Assoc `intersect` Ancestors, AssocId), + goal_store_lookup(GoalStore, AssocId, + stored_goal(AssocGoal, _AssocInstMap)), + AssocGoal = hlds_goal(plain_call(PredId, _, _, _, _, _), _), + + is_associative_construction(ModuleInfo, PredId, ConsId), + ( + % XXX LessThanGoalId was _N - J, not N - J: it ignored the case. + % See the diff with the previous version. + member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId, + stored_goal(EarlierGoal, EarlierInstMap)), + not can_reorder_goals_old(ModuleInfo, VarTypes, FullyStrict, + EarlierInstMap, EarlierGoal, LaterInstMap, LaterGoal) + ) + => + ( + set.member(LessThanGoalId, + set_upto(Case, K) `union` Before `union` Assoc + `union` ConstructAssoc) + ). + + % A goal is a member of the update set iff the goal only depends + % on goals upto and including the recursive call and goals which + % can be moved before the recursive call (member of the before set) + % AND the goal updates some state. + % +:- pred accu_update(module_info::in, vartypes::in, bool::in, + accu_goal_id::in, int::in, accu_goal_store::in, accu_sets::in) is semidet. + +accu_update(ModuleInfo, VarTypes, FullyStrict, GoalId, K, GoalStore, Sets) :- + GoalId = accu_goal_id(Case, _I), + Before = Sets ^ as_before, + goal_store_lookup(GoalStore, GoalId, stored_goal(LaterGoal, LaterInstMap)), + LaterGoal = hlds_goal(plain_call(PredId, _, Args, _, _, _), _), + accu_is_update(ModuleInfo, PredId, Args, _), + ( + % XXX LessThanGoalId was _N - J, not N - J: it ignored the case. + % See the diff with the previous version. + member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId, + stored_goal(EarlierGoal, EarlierInstMap)), + not can_reorder_goals_old(ModuleInfo, VarTypes, FullyStrict, + EarlierInstMap, EarlierGoal, LaterInstMap, LaterGoal) + ) + => + ( + set.member(LessThanGoalId, set_upto(Case, K) `union` Before) + ). + + % member_lessthan_goalid(GS, IdA, IdB, GB) is true iff the goal_id, IdB, + % and its associated goal, GB, is a member of the goal_store, GS, + % and IdB is less than IdA. + % +:- pred member_lessthan_goalid(accu_goal_store::in, + accu_goal_id::in, accu_goal_id::out, stored_goal::out) is nondet. + +member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId, LessThanGoal) :- + goal_store_member(GoalStore, LessThanGoalId, LessThanGoal), + GoalId = accu_goal_id(Case, I), + LessThanGoalId = accu_goal_id(Case, J), + J < I. + +%---------------------------------------------------------------------------% + +:- type accu_assoc + ---> accu_assoc( + set_of_progvar, % the associative input args + prog_var, % the corresponding output arg + bool % is the predicate commutative? + ). + + % If accu_is_associative is true, it returns the two arguments which are + % associative and the variable which depends on those two arguments, + % and an indicator of whether or not the predicate is commutative. + % +:- pred accu_is_associative(module_info::in, pred_id::in, list(prog_var)::in, + accu_assoc::out) is semidet. + +accu_is_associative(ModuleInfo, PredId, Args, Result) :- + module_info_pred_info(ModuleInfo, PredId, PredInfo), + pred_info_get_assertions(PredInfo, Assertions), + AssertionsList = set.to_sorted_list(Assertions), + associativity_assertion(ModuleInfo, AssertionsList, Args, + AssociativeVarsOutputVar), + ( if + commutativity_assertion(ModuleInfo, AssertionsList, Args, + _CommutativeVars) + then + IsCommutative = yes + else + IsCommutative = no + ), + AssociativeVarsOutputVar = + associative_vars_output_var(AssociativeVars, OutputVar), + Result = accu_assoc(AssociativeVars, OutputVar, IsCommutative). + + % Does there exist one (and only one) associativity assertion for the + % current predicate? + % The 'and only one condition' is required because we currently + % do not handle the case of predicates which have individual parts + % which are associative, because then we do not know which variable + % is descended from which. + % +:- pred associativity_assertion(module_info::in, list(assert_id)::in, + list(prog_var)::in, associative_vars_output_var::out) is semidet. + +associativity_assertion(ModuleInfo, [AssertId | AssertIds], Args0, + AssociativeVarsOutputVar) :- + ( if + assertion.is_associativity_assertion(ModuleInfo, AssertId, + Args0, AssociativeVarsOutputVarPrime) + then + AssociativeVarsOutputVar = AssociativeVarsOutputVarPrime, + not associativity_assertion(ModuleInfo, AssertIds, Args0, _) + else + associativity_assertion(ModuleInfo, AssertIds, Args0, + AssociativeVarsOutputVar) + ). + + % Does there exist one (and only one) commutativity assertion for the + % current predicate? + % The 'and only one condition' is required because we currently + % do not handle the case of predicates which have individual + % parts which are commutative, because then we do not know which variable + % is descended from which. + % +:- pred commutativity_assertion(module_info::in,list(assert_id)::in, + list(prog_var)::in, set_of_progvar::out) is semidet. + +commutativity_assertion(ModuleInfo, [AssertId | AssertIds], Args0, + CommutativeVars) :- + ( if + assertion.is_commutativity_assertion(ModuleInfo, AssertId, + Args0, CommutativeVarsPrime) + then + CommutativeVars = CommutativeVarsPrime, + not commutativity_assertion(ModuleInfo, AssertIds, Args0, _) + else + commutativity_assertion(ModuleInfo, AssertIds, Args0, + CommutativeVars) + ). + +%---------------------------------------------------------------------------% + + % Does the current predicate update some state? + % +:- pred accu_is_update(module_info::in, pred_id::in, list(prog_var)::in, + state_update_vars::out) is semidet. + +accu_is_update(ModuleInfo, PredId, Args, ResultStateVars) :- + module_info_pred_info(ModuleInfo, PredId, PredInfo), + pred_info_get_assertions(PredInfo, Assertions), + list.filter_map( + ( pred(AssertId::in, StateVars::out) is semidet :- + assertion.is_update_assertion(ModuleInfo, AssertId, + PredId, Args, StateVars) + ), + set.to_sorted_list(Assertions), Result), + % XXX Maybe we should just match on the first result, + % just in case there are duplicate promises. + Result = [ResultStateVars]. + +%---------------------------------------------------------------------------% + + % Can the construction unification be expressed as a call to the + % specified predicate. + % +:- pred is_associative_construction(module_info::in, pred_id::in, cons_id::in) + is semidet. + +is_associative_construction(ModuleInfo, PredId, ConsId) :- + module_info_pred_info(ModuleInfo, PredId, PredInfo), + pred_info_get_assertions(PredInfo, Assertions), + list.filter( + ( pred(AssertId::in) is semidet :- + assertion.is_construction_equivalence_assertion(ModuleInfo, + AssertId, ConsId, PredId) + ), + set.to_sorted_list(Assertions), Result), + Result = [_ | _]. + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + +:- type accu_substs + ---> accu_substs( + acc_var_subst :: accu_subst, + rec_call_subst :: accu_subst, + assoc_call_subst :: accu_subst, + update_subst :: accu_subst + ). + +:- type accu_base + ---> accu_base( + % goals which initialize update + init_update :: set(accu_goal_id), + + % goals which initialize assoc + init_assoc :: set(accu_goal_id), + + % other goals + other :: set(accu_goal_id) + ). + + % Stage 2 is responsible for identifying the substitutions which + % are needed to mimic the unfold/fold process that was used as + % the justification of the algorithm in the paper. + % It is also responsible for ensuring that the reordering of arguments + % doesn't worsen the big-O complexity of the procedure. + % It also divides the base case into goals that initialize the + % variables used by the update goals, and those used by the assoc + % goals and then all the rest. + % +:- pred accu_stage2(module_info::in, proc_info::in, + accu_goal_id::in, accu_goal_store::in, accu_sets::in, + list(prog_var)::in, list(prog_var)::in, prog_varset::out, vartypes::out, + list(prog_var)::out, accu_base::out, list(pair(prog_var))::out, + accu_substs::out, accu_goal_store::out, list(accu_warning)::out) + is semidet. + +accu_stage2(ModuleInfo, ProcInfo0, GoalId, GoalStore, Sets, OutPrime, Out, + !:VarSet, !:VarTypes, Accs, BaseCase, BasePairs, !:Substs, + CS, Warnings) :- + Sets = accu_sets(Before0, Assoc, ConstructAssoc, Construct, Update, _), + GoalId = accu_goal_id(Case, K), + Before = Before0 `union` set_upto(Case, K-1), + + % Note Update set is not placed in the after set, as the after set is used + % to determine the variables that need to be accumulated for the + % associative calls. + After = Assoc `union` ConstructAssoc `union` Construct, + + P = + ( pred(Id::in, Set0::in, Set::out) is det :- + goal_store_lookup(GoalStore, Id, stored_goal(Goal, _InstMap)), + Goal = hlds_goal(_GoalExpr, GoalInfo), + NonLocals = goal_info_get_nonlocals(GoalInfo), + set_of_var.union(NonLocals, Set0, Set) + ), + list.foldl(P, set.to_sorted_list(Before), + set_of_var.init, BeforeNonLocals), + list.foldl(P, set.to_sorted_list(After), + set_of_var.init, AfterNonLocals), + InitAccs = set_of_var.intersect(BeforeNonLocals, AfterNonLocals), + + proc_info_get_varset(ProcInfo0, !:VarSet), + proc_info_get_vartypes(ProcInfo0, !:VarTypes), + + accu_substs_init(set_of_var.to_sorted_list(InitAccs), !VarSet, !VarTypes, + !:Substs), + + set_of_var.list_to_set(OutPrime, OutPrimeSet), + accu_process_assoc_set(ModuleInfo, GoalStore, set.to_sorted_list(Assoc), + OutPrimeSet, !Substs, !VarSet, !VarTypes, CS, Warnings), + + accu_process_update_set(ModuleInfo, GoalStore, set.to_sorted_list(Update), + OutPrimeSet, !Substs, !VarSet, !VarTypes, UpdateOut, UpdateAccOut, + BasePairs), + + Accs = set_of_var.to_sorted_list(InitAccs) ++ UpdateAccOut, + + accu_divide_base_case(ModuleInfo, !.VarTypes, GoalStore, UpdateOut, Out, + UpdateBase, AssocBase, OtherBase), + + BaseCase = accu_base(UpdateBase, AssocBase, OtherBase). + +%---------------------------------------------------------------------------% + +:- pred accu_substs_init(list(prog_var)::in, prog_varset::in, prog_varset::out, + vartypes::in, vartypes::out, accu_substs::out) is det. + +accu_substs_init(InitAccs, !VarSet, !VarTypes, Substs) :- + map.init(Subst), + acc_var_subst_init(InitAccs, !VarSet, !VarTypes, AccVarSubst), + RecCallSubst = Subst, + AssocCallSubst = Subst, + UpdateSubst = Subst, + Substs = accu_substs(AccVarSubst, RecCallSubst, AssocCallSubst, + UpdateSubst). + + % Initialise the acc_var_subst to be from Var to A_Var where Var is a + % member of InitAccs and A_Var is a fresh variable of the same type of Var. + % +:- pred acc_var_subst_init(list(prog_var)::in, + prog_varset::in, prog_varset::out, vartypes::in, vartypes::out, + accu_subst::out) is det. + +acc_var_subst_init([], !VarSet, !VarTypes, map.init). +acc_var_subst_init([Var | Vars], !VarSet, !VarTypes, Subst) :- + create_new_var(Var, "A_", AccVar, !VarSet, !VarTypes), + acc_var_subst_init(Vars, !VarSet, !VarTypes, Subst0), + map.det_insert(Var, AccVar, Subst0, Subst). + + % Create a fresh variable which is the same type as the old variable + % and has the same name except that it begins with the prefix. + % +:- pred create_new_var(prog_var::in, string::in, prog_var::out, + prog_varset::in, prog_varset::out, vartypes::in, vartypes::out) is det. + +create_new_var(OldVar, Prefix, NewVar, !VarSet, !VarTypes) :- + varset.lookup_name(!.VarSet, OldVar, OldName), + string.append(Prefix, OldName, NewName), + varset.new_named_var(NewName, NewVar, !VarSet), + lookup_var_type(!.VarTypes, OldVar, Type), + add_var_type(NewVar, Type, !VarTypes). + +%---------------------------------------------------------------------------% + + % For each member of the assoc set determine the substitutions needed, + % and also check the efficiency of the procedure isn't worsened + % by reordering the arguments to a call. + % +:- pred accu_process_assoc_set(module_info::in, accu_goal_store::in, + list(accu_goal_id)::in, set_of_progvar::in, + accu_substs::in, accu_substs::out, + prog_varset::in, prog_varset::out, vartypes::in, vartypes::out, + accu_goal_store::out, list(accu_warning)::out) is semidet. + +accu_process_assoc_set(_ModuleInfo, _GS, [], _OutPrime, !Substs, + !VarSet, !VarTypes, CS, []) :- + goal_store_init(CS). +accu_process_assoc_set(ModuleInfo, GS, [Id | Ids], OutPrime, !Substs, + !VarSet, !VarTypes, CS, Warnings) :- + !.Substs = accu_substs(AccVarSubst, RecCallSubst0, AssocCallSubst0, + UpdateSubst), + + lookup_call(GS, Id, stored_goal(Goal, InstMap)), + + Goal = hlds_goal(plain_call(PredId, _, Args, _, _, _), GoalInfo), + accu_is_associative(ModuleInfo, PredId, Args, AssocInfo), + AssocInfo = accu_assoc(Vars, AssocOutput, IsCommutative), + OutPrimeVars = set_of_var.intersect(Vars, OutPrime), + set_of_var.is_singleton(OutPrimeVars, DuringAssocVar), + set_of_var.is_singleton(set_of_var.difference(Vars, OutPrimeVars), + BeforeAssocVar), + + map.lookup(AccVarSubst, BeforeAssocVar, AccVar), + create_new_var(BeforeAssocVar, "NewAcc_", NewAcc, !VarSet, !VarTypes), + + map.det_insert(DuringAssocVar, AccVar, AssocCallSubst0, AssocCallSubst1), + map.det_insert(AssocOutput, NewAcc, AssocCallSubst1, AssocCallSubst), + map.det_insert(DuringAssocVar, AssocOutput, RecCallSubst0, RecCallSubst1), + map.det_insert(BeforeAssocVar, NewAcc, RecCallSubst1, RecCallSubst), + + !:Substs = accu_substs(AccVarSubst, RecCallSubst, AssocCallSubst, + UpdateSubst), + + % ONLY swap the order of the variables if the goal is + % associative and not commutative. + ( + IsCommutative = yes, + CSGoal = stored_goal(Goal, InstMap), + CurWarnings = [] + ; + IsCommutative = no, + + % Ensure that the reordering doesn't cause a efficiency problem. + module_info_pred_info(ModuleInfo, PredId, PredInfo), + ModuleName = pred_info_module(PredInfo), + PredName = pred_info_name(PredInfo), + Arity = pred_info_orig_arity(PredInfo), + ( if accu_has_heuristic(ModuleName, PredName, Arity) then + % Only do the transformation if the accumulator variable is + % *not* in a position where it will control the running time + % of the predicate. + accu_heuristic(ModuleName, PredName, Arity, Args, + PossibleDuringAssocVars), + set_of_var.member(PossibleDuringAssocVars, DuringAssocVar), + CurWarnings = [] + else + ProgContext = goal_info_get_context(GoalInfo), + CurWarnings = [accu_warn(ProgContext, PredId, BeforeAssocVar, + DuringAssocVar)] + ), + % Swap the arguments. + [A, B] = set_of_var.to_sorted_list(Vars), + map.from_assoc_list([A - B, B - A], Subst), + rename_some_vars_in_goal(Subst, Goal, SwappedGoal), + CSGoal = stored_goal(SwappedGoal, InstMap) + ), + + accu_process_assoc_set(ModuleInfo, GS, Ids, OutPrime, !Substs, + !VarSet, !VarTypes, CS0, Warnings0), + goal_store_det_insert(Id, CSGoal, CS0, CS), + Warnings = Warnings0 ++ CurWarnings. + +:- pred accu_has_heuristic(module_name::in, string::in, arity::in) is semidet. + +accu_has_heuristic(unqualified("list"), "append", 3). + + % heuristic returns the set of which head variables are important + % in the running time of the predicate. + % +:- pred accu_heuristic(module_name::in, string::in, arity::in, + list(prog_var)::in, set_of_progvar::out) is semidet. + +accu_heuristic(unqualified("list"), "append", 3, [_Typeinfo, A, _B, _C], + Set) :- + set_of_var.make_singleton(A, Set). + +%---------------------------------------------------------------------------% + + % For each member of the update set determine the substitutions needed + % (creating the accumulator variables when needed). + % Also associate with each Output variable which accumulator variable + % to get the result from. + % +:- pred accu_process_update_set(module_info::in, accu_goal_store::in, + list(accu_goal_id)::in, set_of_progvar::in, + accu_substs::in, accu_substs::out, + prog_varset::in, prog_varset::out, vartypes::in, vartypes::out, + list(prog_var)::out, list(prog_var)::out, list(pair(prog_var))::out) + is semidet. + +accu_process_update_set(_ModuleInfo, _GS, [], _OutPrime, !Substs, + !VarSet, !VarTypes, [], [], []). +accu_process_update_set(ModuleInfo, GS, [Id | Ids], OutPrime, !Substs, + !VarSet, !VarTypes, StateOutputVars, Accs, BasePairs) :- + !.Substs = accu_substs(AccVarSubst0, RecCallSubst0, AssocCallSubst, + UpdateSubst0), + lookup_call(GS, Id, stored_goal(Goal, _InstMap)), + + Goal = hlds_goal(plain_call(PredId, _, Args, _, _, _), _GoalInfo), + accu_is_update(ModuleInfo, PredId, Args, StateVars), + StateVars = state_update_vars(StateVarA, StateVarB), + + ( if set_of_var.member(OutPrime, StateVarA) then + StateInputVar = StateVarA, + StateOutputVar = StateVarB + else + StateInputVar = StateVarB, + StateOutputVar = StateVarA + ), + + create_new_var(StateInputVar, "Acc_", Acc0, !VarSet, !VarTypes), + create_new_var(StateOutputVar, "Acc_", Acc, !VarSet, !VarTypes), + + map.det_insert(StateInputVar, Acc0, UpdateSubst0, UpdateSubst1), + map.det_insert(StateOutputVar, Acc, UpdateSubst1, UpdateSubst), + map.det_insert(StateInputVar, StateOutputVar, RecCallSubst0, RecCallSubst), + map.det_insert(Acc, Acc0, AccVarSubst0, AccVarSubst), + !:Substs = accu_substs(AccVarSubst, RecCallSubst, AssocCallSubst, + UpdateSubst), + + accu_process_update_set(ModuleInfo, GS, Ids, OutPrime, !Substs, + !VarSet, !VarTypes, StateOutputVars0, Accs0, BasePairs0), + + % Rather then concatenating to start of the list we concatenate to the end + % of the list. This allows the accumulator introduction to be applied + % as the heuristic will succeed (remember after transforming the two + % input variables will have their order swapped, so they must be in the + % inefficient order to start with) + + StateOutputVars = StateOutputVars0 ++ [StateOutputVar], + Accs = Accs0 ++ [Acc], + BasePairs = BasePairs0 ++ [StateOutputVar - Acc0]. + +%---------------------------------------------------------------------------% + + % divide_base_case(UpdateOut, Out, U, A, O) is true iff given the output + % variables which are instantiated by update goals, UpdateOut, and all + % the variables that need to be accumulated, Out, divide the base case up + % into three sets, those base case goals which initialize the variables + % used by update calls, U, those which initialize variables used by + % assoc calls, A, and the rest of the goals, O. Note that the sets + % are not necessarily disjoint, as the result of a goal may be used + % to initialize a variable in both U and A, so both U and A will contain + % the same goal_id. + % +:- pred accu_divide_base_case(module_info::in, vartypes::in, + accu_goal_store::in, list(prog_var)::in, list(prog_var)::in, + set(accu_goal_id)::out, set(accu_goal_id)::out, set(accu_goal_id)::out) + is det. + +accu_divide_base_case(ModuleInfo, VarTypes, C, UpdateOut, Out, + UpdateBase, AssocBase, OtherBase) :- + list.delete_elems(Out, UpdateOut, AssocOut), + + list.map(accu_related(ModuleInfo, VarTypes, C), UpdateOut, UpdateBaseList), + list.map(accu_related(ModuleInfo, VarTypes, C), AssocOut, AssocBaseList), + UpdateBase = set.power_union(set.list_to_set(UpdateBaseList)), + AssocBase = set.power_union(set.list_to_set(AssocBaseList)), + + Set = base_case_ids_set(C) `difference` (UpdateBase `union` AssocBase), + set.to_sorted_list(Set, List), + + list.map( + ( pred(GoalId::in, Ancestors::out) is det :- + goal_store_all_ancestors(C, GoalId, VarTypes, + ModuleInfo, no, Ancestors) + ), List, OtherBaseList), + + OtherBase = set.list_to_set(List) `union` + (base_case_ids_set(C) `intersect` + set.power_union(set.list_to_set(OtherBaseList))). + + % accu_related(ModuleInfo, VarTypes, GoalStore, Var, Related): + % + % From GoalStore, return all the goal_ids, Related, which are needed + % to initialize Var. + % +:- pred accu_related(module_info::in, vartypes::in, accu_goal_store::in, + prog_var::in, set(accu_goal_id)::out) is det. + +accu_related(ModuleInfo, VarTypes, GoalStore, Var, Related) :- + solutions.solutions( + ( pred(Key::out) is nondet :- + goal_store_member(GoalStore, Key, stored_goal(Goal, InstMap0)), + Key = accu_goal_id(accu_base, _), + Goal = hlds_goal(_GoalExpr, GoalInfo), + InstMapDelta = goal_info_get_instmap_delta(GoalInfo), + apply_instmap_delta(InstMapDelta, InstMap0, InstMap), + instmap_changed_vars(ModuleInfo, VarTypes, + InstMap0, InstMap, ChangedVars), + set_of_var.is_singleton(ChangedVars, Var) + ), Ids), + ( + Ids = [], + unexpected($pred, "no Id") + ; + Ids = [Id], + goal_store_all_ancestors(GoalStore, Id, VarTypes, ModuleInfo, no, + Ancestors), + list.filter((pred(accu_goal_id(accu_base, _)::in) is semidet), + set.to_sorted_list(set.insert(Ancestors, Id)), RelatedList), + Related = set.list_to_set(RelatedList) + ; + Ids = [_, _ | _], + unexpected($pred, "more than one Id") + ). + +%---------------------------------------------------------------------------% + +:- inst stored_goal_plain_call for goal_store.stored_goal/0 + ---> stored_goal(goal_plain_call, ground). + + % Do a goal_store_lookup where the result is known to be a call. + % +:- pred lookup_call(accu_goal_store::in, accu_goal_id::in, + stored_goal::out(stored_goal_plain_call)) is det. + +lookup_call(GoalStore, Id, stored_goal(Call, InstMap)) :- + goal_store_lookup(GoalStore, Id, stored_goal(Goal, InstMap)), + ( if + Goal = hlds_goal(GoalExpr, GoalInfo), + GoalExpr = plain_call(_, _, _, _, _, _) + then + Call = hlds_goal(GoalExpr, GoalInfo) + else + unexpected($pred, "not a call") + ). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + + % accu_stage3 creates the accumulator version of the predicate using + % the substitutions determined in stage2. It also redefines the + % original procedure to call the accumulator version of the procedure. + % +:- pred accu_stage3(accu_goal_id::in, list(prog_var)::in, prog_varset::in, + vartypes::in, accu_goal_store::in, accu_goal_store::in, + accu_substs::in, accu_subst::in, accu_subst::in, + accu_base::in, list(pair(prog_var))::in, accu_sets::in, + list(prog_var)::in, top_level::in, pred_id::in, pred_info::in, + proc_info::in, proc_info::out, module_info::in, module_info::out) is det. + +accu_stage3(RecCallId, Accs, VarSet, VarTypes, C, CS, Substs, + HeadToCallSubst, CallToHeadSubst, BaseCase, BasePairs, Sets, Out, + TopLevel, OrigPredId, OrigPredInfo, !OrigProcInfo, !ModuleInfo) :- + acc_proc_info(Accs, VarSet, VarTypes, Substs, !.OrigProcInfo, + AccTypes, AccProcInfo), + acc_pred_info(AccTypes, Out, AccProcInfo, OrigPredId, OrigPredInfo, + AccProcId, AccPredInfo), + AccName = unqualified(pred_info_name(AccPredInfo)), + + module_info_get_predicate_table(!.ModuleInfo, PredTable0), + predicate_table_insert(AccPredInfo, AccPredId, PredTable0, PredTable), + module_info_set_predicate_table(PredTable, !ModuleInfo), + accu_create_goal(RecCallId, Accs, AccPredId, AccProcId, AccName, Substs, + HeadToCallSubst, CallToHeadSubst, BaseCase, BasePairs, Sets, C, CS, + OrigBaseGoal, OrigRecGoal, AccBaseGoal, AccRecGoal), + + proc_info_get_goal(!.OrigProcInfo, OrigGoal0), + accu_top_level(TopLevel, OrigGoal0, OrigBaseGoal, OrigRecGoal, + AccBaseGoal, AccRecGoal, OrigGoal, AccGoal), + + proc_info_set_goal(OrigGoal, !OrigProcInfo), + proc_info_set_varset(VarSet, !OrigProcInfo), + proc_info_set_vartypes(VarTypes, !OrigProcInfo), + + requantify_proc_general(ordinary_nonlocals_no_lambda, !OrigProcInfo), + update_accumulator_pred(AccPredId, AccProcId, AccGoal, !ModuleInfo). + +%---------------------------------------------------------------------------% + + % Construct a proc_info for the introduced predicate. + % +:- pred acc_proc_info(list(prog_var)::in, prog_varset::in, vartypes::in, + accu_substs::in, proc_info::in, list(mer_type)::out, proc_info::out) + is det. + +acc_proc_info(Accs0, VarSet, VarTypes, Substs, OrigProcInfo, + AccTypes, AccProcInfo) :- + % ProcInfo Stuff that must change. + proc_info_get_headvars(OrigProcInfo, HeadVars0), + proc_info_get_argmodes(OrigProcInfo, HeadModes0), + + proc_info_get_inst_varset(OrigProcInfo, InstVarSet), + proc_info_get_inferred_determinism(OrigProcInfo, Detism), + proc_info_get_goal(OrigProcInfo, Goal), + proc_info_get_context(OrigProcInfo, Context), + proc_info_get_rtti_varmaps(OrigProcInfo, RttiVarMaps), + proc_info_get_is_address_taken(OrigProcInfo, IsAddressTaken), + proc_info_get_has_parallel_conj(OrigProcInfo, HasParallelConj), + proc_info_get_var_name_remap(OrigProcInfo, VarNameRemap), + + Substs = accu_substs(AccVarSubst, _RecCallSubst, _AssocCallSubst, + _UpdateSubst), + list.map(map.lookup(AccVarSubst), Accs0, Accs), + + % We place the extra accumulator variables at the start, because placing + % them at the end breaks the convention that the last variable of a + % function is the output variable. + HeadVars = Accs ++ HeadVars0, + + % XXX we don't want to use the inst of the var as it can be more specific + % than it should be. ie int_const(1) when it should be any integer. + % However this will no longer handle partially instantiated data + % structures. + Inst = ground(shared, none_or_default_func), + inst_lists_to_mode_list([Inst], [Inst], Mode), + list.duplicate(list.length(Accs), list.det_head(Mode), AccModes), + HeadModes = AccModes ++ HeadModes0, + + lookup_var_types(VarTypes, Accs, AccTypes), + + SeqNum = item_no_seq_num, + proc_info_create(Context, SeqNum, VarSet, VarTypes, HeadVars, + InstVarSet, HeadModes, detism_decl_none, Detism, Goal, RttiVarMaps, + IsAddressTaken, HasParallelConj, VarNameRemap, AccProcInfo). + +%---------------------------------------------------------------------------% + + % Construct the pred_info for the introduced predicate. + % +:- pred acc_pred_info(list(mer_type)::in, list(prog_var)::in, proc_info::in, + pred_id::in, pred_info::in, proc_id::out, pred_info::out) is det. + +acc_pred_info(NewTypes, OutVars, NewProcInfo, OrigPredId, OrigPredInfo, + NewProcId, NewPredInfo) :- + % PredInfo stuff that must change. + pred_info_get_arg_types(OrigPredInfo, TypeVarSet, ExistQVars, Types0), + + ModuleName = pred_info_module(OrigPredInfo), + Name = pred_info_name(OrigPredInfo), + PredOrFunc = pred_info_is_pred_or_func(OrigPredInfo), + pred_info_get_context(OrigPredInfo, PredContext), + pred_info_get_markers(OrigPredInfo, Markers), + pred_info_get_class_context(OrigPredInfo, ClassContext), + pred_info_get_origin(OrigPredInfo, OldOrigin), + pred_info_get_var_name_remap(OrigPredInfo, VarNameRemap), + + set.init(Assertions), + + proc_info_get_context(NewProcInfo, Context), + term.context_line(Context, Line), + Counter = 0, + + Types = NewTypes ++ Types0, + + make_pred_name_with_context(ModuleName, "AccFrom", PredOrFunc, Name, + Line, Counter, SymName), + + OutVarNums = list.map(term.var_to_int, OutVars), + Origin = origin_transformed(transform_accumulator(OutVarNums), + OldOrigin, OrigPredId), + GoalType = goal_not_for_promise(np_goal_type_none), + pred_info_create(ModuleName, SymName, PredOrFunc, PredContext, Origin, + pred_status(status_local), Markers, Types, TypeVarSet, + ExistQVars, ClassContext, Assertions, VarNameRemap, GoalType, + NewProcInfo, NewProcId, NewPredInfo). + +%---------------------------------------------------------------------------% + + % create_goal creates the new base and recursive case of the + % original procedure (OrigBaseGoal and OrigRecGoal) and the base + % and recursive cases of accumulator version (AccBaseGoal and + % AccRecGoal). + % +:- pred accu_create_goal(accu_goal_id::in, list(prog_var)::in, + pred_id::in, proc_id::in, sym_name::in, accu_substs::in, + accu_subst::in, accu_subst::in, accu_base::in, + list(pair(prog_var))::in, accu_sets::in, + accu_goal_store::in, accu_goal_store::in, + hlds_goal::out, hlds_goal::out, hlds_goal::out, hlds_goal::out) is det. + +accu_create_goal(RecCallId, Accs, AccPredId, AccProcId, AccName, Substs, + HeadToCallSubst, CallToHeadSubst, BaseIds, BasePairs, + Sets, C, CS, OrigBaseGoal, OrigRecGoal, AccBaseGoal, AccRecGoal) :- + lookup_call(C, RecCallId, stored_goal(OrigCall, _InstMap)), + Call = create_acc_call(OrigCall, Accs, AccPredId, AccProcId, AccName), + create_orig_goal(Call, Substs, HeadToCallSubst, CallToHeadSubst, + BaseIds, Sets, C, OrigBaseGoal, OrigRecGoal), + create_acc_goal(Call, Substs, HeadToCallSubst, BaseIds, BasePairs, + Sets, C, CS, AccBaseGoal, AccRecGoal). + + % create_acc_call takes the original call and generates a call to the + % accumulator version of the call, which can have the substitutions + % applied to it easily. + % +:- func create_acc_call(hlds_goal::in(goal_plain_call), list(prog_var)::in, + pred_id::in, proc_id::in, sym_name::in) = (hlds_goal::out(goal_plain_call)) + is det. + +create_acc_call(OrigCall, Accs, AccPredId, AccProcId, AccName) = Call :- + OrigCall = hlds_goal(OrigCallExpr, GoalInfo), + OrigCallExpr = plain_call(_PredId, _ProcId, Args, Builtin, Context, _Name), + CallExpr = plain_call(AccPredId, AccProcId, Accs ++ Args, Builtin, + Context, AccName), + Call = hlds_goal(CallExpr, GoalInfo). + + % Create the goals which are to replace the original predicate. + % +:- pred create_orig_goal(hlds_goal::in, accu_substs::in, + accu_subst::in, accu_subst::in, accu_base::in, accu_sets::in, + accu_goal_store::in, hlds_goal::out, hlds_goal::out) is det. + +create_orig_goal(Call, Substs, HeadToCallSubst, CallToHeadSubst, + BaseIds, Sets, C, OrigBaseGoal, OrigRecGoal) :- + Substs = accu_substs(_AccVarSubst, _RecCallSubst, _AssocCallSubst, + UpdateSubst), + + BaseIds = accu_base(UpdateBase, _AssocBase, _OtherBase), + Before = Sets ^ as_before, + Update = Sets ^ as_update, + + U = create_new_orig_recursive_goals(UpdateBase, Update, + HeadToCallSubst, UpdateSubst, C), + + rename_some_vars_in_goal(CallToHeadSubst, Call, BaseCall), + Cbefore = accu_goal_list(set.to_sorted_list(Before), C), + Uupdate = accu_goal_list(set.to_sorted_list(UpdateBase) ++ + set.to_sorted_list(Update), U), + Cbase = accu_goal_list(base_case_ids(C), C), + calculate_goal_info(conj(plain_conj, Cbefore ++ Uupdate ++ [BaseCall]), + OrigRecGoal), + calculate_goal_info(conj(plain_conj, Cbase), OrigBaseGoal). + + % Create the goals which are to go in the new accumulator version + % of the predicate. + % +:- pred create_acc_goal(hlds_goal::in, accu_substs::in, accu_subst::in, + accu_base::in, list(pair(prog_var))::in, accu_sets::in, + accu_goal_store::in, accu_goal_store::in, + hlds_goal::out, hlds_goal::out) is det. + +create_acc_goal(Call, Substs, HeadToCallSubst, BaseIds, BasePairs, Sets, + C, CS, AccBaseGoal, AccRecGoal) :- + Substs = accu_substs(AccVarSubst, RecCallSubst, AssocCallSubst, + UpdateSubst), + + BaseIds = accu_base(_UpdateBase, AssocBase, OtherBase), + Sets = accu_sets(Before, Assoc, ConstructAssoc, Construct, Update, + _Reject), + + rename_some_vars_in_goal(RecCallSubst, Call, RecCall), + + Cbefore = accu_goal_list(set.to_sorted_list(Before), C), + + % Create the goals which will be used in the new recursive case. + R = create_new_recursive_goals(Assoc, Construct `union` ConstructAssoc, + Update, AssocCallSubst, AccVarSubst, UpdateSubst, C, CS), + + Rassoc = accu_goal_list(set.to_sorted_list(Assoc), R), + Rupdate = accu_goal_list(set.to_sorted_list(Update), R), + Rconstruct = accu_goal_list(set.to_sorted_list(Construct `union` + ConstructAssoc), R), + + % Create the goals which will be used in the new base case. + B = create_new_base_goals(Assoc `union` Construct `union` + ConstructAssoc, C, AccVarSubst, HeadToCallSubst), + Bafter = set.to_sorted_list(Assoc `union` + Construct `union` ConstructAssoc), + + BaseCase = accu_goal_list(set.to_sorted_list(AssocBase `union` OtherBase) + ++ Bafter, B), + + list.map(acc_unification, BasePairs, UpdateBase), + + calculate_goal_info(conj(plain_conj, Cbefore ++ Rassoc ++ Rupdate + ++ [RecCall] ++ Rconstruct), AccRecGoal), + calculate_goal_info(conj(plain_conj, UpdateBase ++ BaseCase), AccBaseGoal). + + % Create the U set of goals (those that will be used in the original + % recursive case) by renaming all the goals which are used to initialize + % the update state variable using the head_to_call followed by the + % update_subst, and rename all the update goals using the update_subst. + % +:- func create_new_orig_recursive_goals(set(accu_goal_id), set(accu_goal_id), + accu_subst, accu_subst, accu_goal_store) = accu_goal_store. + +create_new_orig_recursive_goals(UpdateBase, Update, HeadToCallSubst, + UpdateSubst, C) + = accu_rename(set.to_sorted_list(Update), UpdateSubst, C, Ubase) :- + Ubase = accu_rename(set.to_sorted_list(UpdateBase), + chain_subst(HeadToCallSubst, UpdateSubst), C, goal_store_init). + + % Create the R set of goals (those that will be used in the new + % recursive case) by renaming all the members of assoc in CS + % using assoc_call_subst and all the members of (construct U + % construct_assoc) in C with acc_var_subst. + % +:- func create_new_recursive_goals(set(accu_goal_id), set(accu_goal_id), + set(accu_goal_id), accu_subst, accu_subst, accu_subst, + accu_goal_store, accu_goal_store) = accu_goal_store. + +create_new_recursive_goals(Assoc, Constructs, Update, + AssocCallSubst, AccVarSubst, UpdateSubst, C, CS) + = accu_rename(set.to_sorted_list(Constructs), AccVarSubst, C, RBase) :- + RBase0 = accu_rename(set.to_sorted_list(Assoc), AssocCallSubst, CS, + goal_store_init), + RBase = accu_rename(set.to_sorted_list(Update), UpdateSubst, C, RBase0). + + % Create the B set of goals (those that will be used in the new base case) + % by renaming all the base case goals of C with head_to_call and all the + % members of (assoc U construct U construct_assoc) of C with acc_var_subst. + % +:- func create_new_base_goals(set(accu_goal_id), accu_goal_store, + accu_subst, accu_subst) = accu_goal_store. + +create_new_base_goals(Ids, C, AccVarSubst, HeadToCallSubst) + = accu_rename(set.to_sorted_list(Ids), AccVarSubst, C, Bbase) :- + Bbase = accu_rename(base_case_ids(C), HeadToCallSubst, C, goal_store_init). + + % acc_unification(O-A, G): + % + % is true if G represents the assignment unification Out = Acc. + % +:- pred acc_unification(pair(prog_var)::in, hlds_goal::out) is det. + +acc_unification(Out - Acc, Goal) :- + UnifyMode = unify_modes_li_lf_ri_rf(free, ground_inst, + ground_inst, ground_inst), + Context = unify_context(umc_explicit, []), + Expr = unify(Out, rhs_var(Acc), UnifyMode, assign(Out,Acc), Context), + set_of_var.list_to_set([Out, Acc], NonLocalVars), + InstMapDelta = instmap_delta_bind_var(Out), + goal_info_init(NonLocalVars, InstMapDelta, detism_det, purity_pure, Info), + Goal = hlds_goal(Expr, Info). + +%---------------------------------------------------------------------------% + + % Given the top level structure of the goal create new version + % with new base and recursive cases plugged in. + % +:- pred accu_top_level(top_level::in, hlds_goal::in, + hlds_goal::in, hlds_goal::in, hlds_goal::in, + hlds_goal::in, hlds_goal::out, hlds_goal::out) is det. + +accu_top_level(TopLevel, Goal, OrigBaseGoal, OrigRecGoal, + NewBaseGoal, NewRecGoal, OrigGoal, NewGoal) :- + ( + TopLevel = switch_base_rec, + ( if + Goal = hlds_goal(switch(Var, CanFail, Cases0), GoalInfo), + Cases0 = [case(IdA, [], _), case(IdB, [], _)] + then + OrigCases = [case(IdA, [], OrigBaseGoal), + case(IdB, [], OrigRecGoal)], + OrigGoal = hlds_goal(switch(Var, CanFail, OrigCases), GoalInfo), + + NewCases = [case(IdA, [], NewBaseGoal), case(IdB, [], NewRecGoal)], + NewGoal = hlds_goal(switch(Var, CanFail, NewCases), GoalInfo) + else + unexpected($pred, "not the correct top level") + ) + ; + TopLevel = switch_rec_base, + ( if + Goal = hlds_goal(switch(Var, CanFail, Cases0), GoalInfo), + Cases0 = [case(IdA, [], _), case(IdB, [], _)] + then + OrigCases = [case(IdA, [], OrigRecGoal), + case(IdB, [], OrigBaseGoal)], + OrigGoal = hlds_goal(switch(Var, CanFail, OrigCases), GoalInfo), + + NewCases = [case(IdA, [], NewRecGoal), case(IdB, [], NewBaseGoal)], + NewGoal = hlds_goal(switch(Var, CanFail, NewCases), GoalInfo) + else + unexpected($pred, "not the correct top level") + ) + ; + TopLevel = disj_base_rec, + ( if + Goal = hlds_goal(disj(Goals), GoalInfo), + Goals = [_, _] + then + OrigGoals = [OrigBaseGoal, OrigRecGoal], + OrigGoal = hlds_goal(disj(OrigGoals), GoalInfo), + + NewGoals = [NewBaseGoal, NewRecGoal], + NewGoal = hlds_goal(disj(NewGoals), GoalInfo) + else + unexpected($pred, "not the correct top level") + ) + ; + TopLevel = disj_rec_base, + ( if + Goal = hlds_goal(disj(Goals), GoalInfo), + Goals = [_, _] + then + OrigGoals = [OrigRecGoal, OrigBaseGoal], + OrigGoal = hlds_goal(disj(OrigGoals), GoalInfo), + + NewGoals = [NewRecGoal, NewBaseGoal], + NewGoal = hlds_goal(disj(NewGoals), GoalInfo) + else + unexpected($pred, "not the correct top level") + ) + ; + TopLevel = ite_base_rec, + ( if Goal = hlds_goal(if_then_else(Vars, Cond, _, _), GoalInfo) then + OrigGoal = hlds_goal(if_then_else(Vars, Cond, + OrigBaseGoal, OrigRecGoal), GoalInfo), + NewGoal = hlds_goal(if_then_else(Vars, Cond, + NewBaseGoal, NewRecGoal), GoalInfo) + else + unexpected($pred, "not the correct top level") + ) + ; + TopLevel = ite_rec_base, + ( if Goal = hlds_goal(if_then_else(Vars, Cond, _, _), GoalInfo) then + OrigGoal = hlds_goal(if_then_else(Vars, Cond, + OrigRecGoal, OrigBaseGoal), GoalInfo), + NewGoal = hlds_goal(if_then_else(Vars, Cond, + NewRecGoal, NewBaseGoal), GoalInfo) + else + unexpected($pred, "not the correct top level") + ) + ). + +%---------------------------------------------------------------------------% + + % Place the accumulator version of the predicate in the HLDS. + % +:- pred update_accumulator_pred(pred_id::in, proc_id::in, + hlds_goal::in, module_info::in, module_info::out) is det. + +update_accumulator_pred(NewPredId, NewProcId, AccGoal, !ModuleInfo) :- + module_info_pred_proc_info(!.ModuleInfo, NewPredId, NewProcId, + PredInfo, ProcInfo0), + proc_info_set_goal(AccGoal, ProcInfo0, ProcInfo1), + requantify_proc_general(ordinary_nonlocals_no_lambda, ProcInfo1, ProcInfo), + module_info_set_pred_proc_info(NewPredId, NewProcId, + PredInfo, ProcInfo, !ModuleInfo). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + + % accu_rename(Ids, Subst, From, Initial): + % + % Return a goal_store, Final, which is the result of looking up each + % member of set of goal_ids, Ids, in the goal_store, From, applying + % the substitution and then storing the goal into the goal_store, Initial. + % +:- func accu_rename(list(accu_goal_id), accu_subst, + accu_goal_store, accu_goal_store) = accu_goal_store. + +accu_rename(Ids, Subst, From, Initial) = Final :- + list.foldl( + ( pred(Id::in, GS0::in, GS::out) is det :- + goal_store_lookup(From, Id, stored_goal(Goal0, InstMap)), + rename_some_vars_in_goal(Subst, Goal0, Goal), + goal_store_det_insert(Id, stored_goal(Goal, InstMap), GS0, GS) + ), Ids, Initial, Final). + + % Return all the goal_ids which belong in the base case. + % +:- func base_case_ids(accu_goal_store) = list(accu_goal_id). + +base_case_ids(GS) = Base :- + solutions.solutions( + ( pred(Key::out) is nondet :- + goal_store_member(GS, Key, _Goal), + Key = accu_goal_id(accu_base, _) + ), Base). + +:- func base_case_ids_set(accu_goal_store) = set(accu_goal_id). + +base_case_ids_set(GS) = set.list_to_set(base_case_ids(GS)). + + % Given a list of goal_ids, return the list of hlds_goals from + % the goal_store. + % +:- func accu_goal_list(list(accu_goal_id), accu_goal_store) = list(hlds_goal). + +accu_goal_list(Ids, GS) = Goals :- + list.map( + ( pred(Key::in, G::out) is det :- + goal_store_lookup(GS, Key, stored_goal(G, _)) + ), Ids, Goals). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + +:- pred calculate_goal_info(hlds_goal_expr::in, hlds_goal::out) is det. + +calculate_goal_info(GoalExpr, hlds_goal(GoalExpr, GoalInfo)) :- + ( if GoalExpr = conj(plain_conj, GoalList) then + goal_list_nonlocals(GoalList, NonLocals), + goal_list_instmap_delta(GoalList, InstMapDelta), + goal_list_determinism(GoalList, Detism), + + goal_info_init(NonLocals, InstMapDelta, Detism, purity_pure, GoalInfo) + else + unexpected($pred, "not a conj") + ). + +%---------------------------------------------------------------------------% +%---------------------------------------------------------------------------% + +:- func chain_subst(accu_subst, accu_subst) = accu_subst. + +chain_subst(AtoB, BtoC) = AtoC :- + map.keys(AtoB, Keys), + chain_subst_2(Keys, AtoB, BtoC, AtoC). + +:- pred chain_subst_2(list(A)::in, map(A, B)::in, map(B, C)::in, + map(A, C)::out) is det. + +chain_subst_2([], _, _, AtoC) :- + map.init(AtoC). +chain_subst_2([A | As], AtoB, BtoC, AtoC) :- + chain_subst_2(As, AtoB, BtoC, AtoC0), + map.lookup(AtoB, A, B), + ( if map.search(BtoC, B, C) then + map.det_insert(A, C, AtoC0, AtoC) + else + AtoC = AtoC0 + ). + +%---------------------------------------------------------------------------% +:- end_module transform_hlds.accumulator. +%---------------------------------------------------------------------------% -- 2.31.1