johnco3
johnco3

Reputation: 2652

libTooling application to rewrite statements while preserving leading and trailing comments

I am writing a source-to-source transformation tool using Clang libTooling to transform C source code. This involves rewriting the following statement types: clang::IfStmt, clang::WhileStmt, clang::ForStmt and clang::DoStmt. Each of these statements contains a condition. The example included below shows how the condition is handled for clang::IfStmt, but the same code applies to the other statements.

I need to access the entire condition text (including any leading and trailing comments) and rewrite if (condition_text) as if (/*BEGIN*/condition_text/*END*/) while preserving any leading or trailing comment text in condition_text. For example, if (/*FOO*/i<10/*BAR*/) should be rewritten as if (/*BEGIN*//*FOO*/i<10/*BAR*//*END*/). I do not understand why I cannot use clang::Lexer to search backwards and forwards through raw tokens to find the leading and trailing comments. The IfStmt callback from the RecursiveASTVisitor is shown as an example below:

    //! Visitor callback for 'clang::IfStmt'.
    //! Replaces the condition text (as returned by getSourceFromStmt) with
    //! "/*BEGIN*/<condition>/*END*/" and counts the probe in gProbeIndex.
    bool VisitIfStmt(const clang::IfStmt *IS) const {
        if (IS->getCond()) {
            if (const auto& [SR, condString] = getSourceFromStmt(
                IS->getCond()); !condString.empty()) {
                // Replace the condition string.
                const auto probeText = std::format("/*BEGIN*/{}/*END*/", condString);
                // Count the probe explicitly; the index is not part of the
                // emitted text, so it must not be hidden in the format arguments.
                ++gProbeIndex;
                mRewriter.ReplaceText(SR, probeText);
            }
        }
        // Returning true continues the traversal.
        return true;
    }

This uses a helper function, getSourceFromStmt, which returns the updated source range and condition string; that helper is where I am having problems.

Running the above against some very simple C code shown below (note that the C code contains TRUE/FALSE macros, so we have to be careful to use expansion locations):

#define TRUE 1
#define FALSE 0

void foo() {
    int t = 0;
    // Leading Macro test
    if (TRUE == 1) {
    }

    // Leading Comment, trailing Macro
    if (/*COMMENT*/0 == FALSE) {
    }

    // trailing Comment after Macro
    if (0 == TRUE/*COMMENT*/) {
    }

    // Leading comment
    if (/*COMMENT*/2 < t) {
    }

    // Trailing comment
    if (t < 2 /*COMMENT*/) {
    }
}

produces the following rewritten text.

#define TRUE 1
#define FALSE 0

void foo() {
    int t = 0;
    // Leading Macro test
    if (/*BEGIN*/TRUE == 1/*END*/) {
    }

    // Leading Comment, trailing Macro
    if (/*COMMENT*//*BEGIN*/0 == FALSE/*END*/) {
    }

    // trailing Comment after Macro
    if (/*BEGIN*/0 == TRUE/*COMMENT*/)/*END*/ {
    }

    // Leading comment
    if (/*COMMENT*//*BEGIN*/2 < t/*END*/) {
    }

    // Trailing comment
    if (/*BEGIN*/t < 2/*END*/ /*COMMENT*/) {
    }
}

The full application source is shown below:

// SYSTEM INCLUDES
#include <iostream>
#include <filesystem>
#include <clang/AST/ASTContext.h>
#include <clang/AST/RecursiveASTVisitor.h>
#include <clang/Basic/Diagnostic.h>
#include <clang/Basic/DiagnosticOptions.h>
#include <clang/Basic/SourceLocation.h>
#include <clang/Basic/SourceManager.h>
#include <clang/Frontend/ASTUnit.h>
#include <clang/Frontend/CompilerInstance.h>
#include <clang/Rewrite/Core/Rewriter.h>
#include <clang/Serialization/PCHContainerOperations.h>
using namespace llvm;

namespace {
    // Option category for this tool's command-line options.
    // NOTE(review): nothing in this program parses llvm::cl options, so the
    // category is registered but appears to be unused.
    cl::OptionCategory gToolCategory{"Tool Category"};

    // Running count of /*BEGIN*/.../*END*/ probes inserted by the visitor.
    unsigned gProbeIndex{0u};
}

namespace tooling {
    //! Wraps the condition of each visited `if` / `for` statement in
    //! /*BEGIN*/ ... /*END*/ marker comments through a clang::Rewriter.
    class CProbeVisitor : public clang::RecursiveASTVisitor<CProbeVisitor> {
    public:
        //! @param rContext  AST context supplying the SourceManager and LangOptions.
        //! @param rRewriter Rewriter that receives the edits; it is held by
        //!                  reference, so it must outlive the visitor.
        explicit CProbeVisitor(
            clang::ASTContext& rContext,
            clang::Rewriter& rRewriter)
            : mContext{rContext}
            , mRewriter{rRewriter}
        {}

        static bool shouldTraversePostOrder() {
            // Must return true to traverse the AST in post-order.
            return true;
        }

        bool TraverseFunctionDecl(clang::FunctionDecl *FD) {
            const auto& SM = mContext.getSourceManager();
            if (!SM.isInMainFile(SM.getExpansionLoc(FD->getLocation()))) {
                // Skip this FunctionDecl and continue traversal.
                return true;
            }

            // Breadcrumb to access function name in other VisitStmt callbacks.
            mCurrentFunction = FD;

            // Now traverse the function body - should not need to check
            // locations in the visited statement callbacks.
            const bool shouldContinue =
                RecursiveASTVisitor::TraverseFunctionDecl(FD);

            if (!shouldContinue) {
                // Stop traversal.
                return false;
            }

            // This is called when exiting the function.
            // You can put your function exit logic here.
            return true;  // Continue traversal.
        }

        //! Visitor callback for 'clang::IfStmt'.
        bool VisitIfStmt(const clang::IfStmt *IS) const {
            if (IS->getCond()) {
                if (const auto& [SR, condString] = getSourceFromStmt(
                    IS->getCond()); !condString.empty()) {
                    // Replace the condition string.
                    const auto probeText = std::format("/*BEGIN*/{}/*END*/", condString);
                    // Count the probe explicitly; the index is not part of the
                    // emitted text, so it must not be hidden in the format arguments.
                    ++gProbeIndex;
                    mRewriter.ReplaceText(SR, probeText);
                }
            }
            // Returning true continues the traversal.
            return true;
        }

        //! Visitor callback for "clang::ForStmt".
        bool VisitForStmt(const clang::ForStmt *FS) const {
            if (FS->getCond()) {
                if (const auto& [SR, condString] = getSourceFromStmt(
                    FS->getCond()); !condString.empty()) {
                    // Replace the condition string.
                    const auto probeText = std::format("/*BEGIN*/{}/*END*/", condString);
                    // Count the probe explicitly (see VisitIfStmt).
                    ++gProbeIndex;
                    mRewriter.ReplaceText(SR, probeText);
                }
            }
            // Returning true continues the traversal.
            return true;
        }

        //! Visitor callback for "clang::TranslationUnitDecl".
        static bool VisitTranslationUnitDecl(const clang::TranslationUnitDecl *TU) {
            return true;
        }

    private:
        //! Helper method to get the stmt source (accounting for prior rewrites & expansion locations).
        //! Returns an empty pair when the range does not begin and end in the same file.
        std::pair<clang::CharSourceRange, std::string> getSourceFromStmt(const clang::Stmt* stmt) const {
            const auto& SM = mContext.getSourceManager();
            // In practice this range starts at the condition itself rather than
            // at a leading comment: see the NOTE(review) comments in
            // getSourceRangeWithComments.
            const auto SR = getSourceRangeWithComments(stmt);
            if (SM.isWrittenInSameFile(SR.getBegin(), SR.getEnd())) {
                // Get the text *after* taking prior rewrites into account.
                return { SR, mRewriter.getRewrittenText(SR) };
            }
            return {};
        }

        //! Intended to widen the statement's expansion range so it also covers
        //! comments directly before and after it. Returns a character range.
        auto getSourceRangeWithComments(
            const clang::Stmt* stmt) const -> clang::CharSourceRange {
            const auto& SM = mContext.getSourceManager();
            const auto& LO = mContext.getLangOpts();

            // Get the expansion SourceRange of the expression.
            const auto SR = SM.getExpansionRange(stmt->getSourceRange());
            auto beginLoc = SR.getBegin();
            auto endLoc = SR.getEnd();

            // Adjust the beginning location backwards from the SR.getBegin.
            // NOTE(review): this loop never finds a leading comment. Raw-lexing
            // from beginLoc-1 starts on the *last* character of whatever precedes
            // the condition (e.g. the final '/' of "*/"), which lexes as a '/'
            // token, not a comment. C cannot be lexed backwards reliably. Lex
            // forward from a known anchor (IfStmt::getLParenLoc(), the semicolons
            // of a ForStmt) instead.
            beginLoc = clang::Lexer::GetBeginningOfToken(beginLoc, SM, LO);
            while (SM.isWrittenInSameFile(beginLoc, SR.getBegin())) {
                clang::Token token;
                if (!clang::Lexer::getRawToken(beginLoc.getLocWithOffset(-1), token, SM, LO) &&
                    token.is(clang::tok::comment)) {
                    beginLoc = token.getLocation();
                } else {
                    // failed to get the raw token.
                    break;
                }
            }

            // Adjust the end location to the end of the trailing comments
            // NOTE(review): this loop has two problems.
            //  (1) getRawToken() fails when endLoc is whitespace, so a comment
            //      preceded by a space ("t < 2 /*C*/") is never reached.
            //  (2) token.getEndLoc() already points one past the comment; passing
            //      it to getLocForEndOfToken() also skips the *next* token, which
            //      is how the ')' ends up inside /*BEGIN*/.../*END*/.
            //      `endLoc = token.getEndLoc();` would stop right after the comment.
            endLoc = clang::Lexer::getLocForEndOfToken(endLoc, 0, SM, LO);
            while (SM.isWrittenInSameFile(SR.getEnd(), endLoc)) {
                clang::Token token;
                if (!clang::Lexer::getRawToken(endLoc, token, SM, LO) &&
                    token.is(clang::tok::comment)) {
                    endLoc = clang::Lexer::getLocForEndOfToken(token.getEndLoc(), 0, SM, LO);
                } else {
                    break;
                }
            }

            // Now beginLoc and endLoc include the leading and trailing comments
            return {{beginLoc, endLoc}, false};
        }

        clang::ASTContext& mContext;
        clang::Rewriter& mRewriter;
        const clang::FunctionDecl *mCurrentFunction = nullptr;
    };
}

// This is all boilerplate for a program using the Clang C++ API
// ("libTooling") but not using the "tooling" part specifically.
int main(int argc, char const **argv)
{
    // Copy the arguments into a vector of char pointers since that is
    // what 'createInvocationFromCommandLine' wants.
    std::vector<char const *> commandLine;
    {
        // Path to the 'clang' binary that I am behaving like.  This path is
        // used to compute the location of compiler headers like stddef.h.
        // The Makefile sets 'CLANG_LLVM_INSTALL_DIR' on the compilation
        // command line.
        //commandLine.push_back("C:/tools/llvm-project/build-host/debug/bin/clang");
        commandLine.push_back("C:/tools/llvm/bin/clang");

        for (int i = 1; i < argc; ++i) {
            commandLine.push_back(argv[i]);
        }
    }

    // Parse the command line options.
    const std::shared_ptr<clang::CompilerInvocation> compilerInvocation(
        clang::createInvocation(llvm::ArrayRef(commandLine)));
    if (!compilerInvocation) {
        // Command line parsing errors have already been printed.
        return 2;
    }

    // Boilerplate setup for 'LoadFromCompilerInvocationAction'.
    const auto pchContainerOps = std::make_shared<clang::PCHContainerOperations>();
    const clang::IntrusiveRefCntPtr<clang::DiagnosticsEngine> diagnosticsEngine(
        clang::CompilerInstance::createDiagnostics(
            new clang::DiagnosticOptions));

    // Run the Clang parser to produce an AST.
    const std::unique_ptr<clang::ASTUnit> ast(
        clang::ASTUnit::LoadFromCompilerInvocationAction(
            compilerInvocation,
            pchContainerOps,
            diagnosticsEngine));

    if (ast == nullptr || diagnosticsEngine->getNumErrors() > 0) {
        // Error messages have already been printed.
        return 2;
    }

    clang::ASTContext& astContext = ast->getASTContext();
    const auto& SM = astContext.getSourceManager();
    clang::Rewriter rewriter(astContext.getSourceManager(), astContext.getLangOpts());


    tooling::CProbeVisitor visitor(astContext, rewriter);
    visitor.TraverseDecl(astContext.getTranslationUnitDecl());

    const auto MainFileID = SM.getMainFileID();
    const auto MainFileRange = clang::SourceRange(
        SM.getLocForStartOfFile(MainFileID),
        SM.getLocForEndOfFile(MainFileID));

    const auto FinalSourceCode = rewriter.getRewrittenText(MainFileRange);

    // Print the final source code to the console
    std::cout << FinalSourceCode << '\n';

    return 0;
}

Upvotes: 1

Views: 116

Answers (1)

Scott McPeak
Scott McPeak

Reputation: 12863

Difficulties lexing backward

The code in the question tries to lex the input backwards in order to go from the first token in the for condition to, and past, any preceding comment:

    if (/*COMMENT*/2 < t) {
        ^          ^
        |          start
        |
       goal

This is basically impossible, since the lexical rules for C do not allow for an unambiguous backward analysis. One issue is that comments do not nest, so if you see /**/ (going backwards) you cannot tell whether to stop because the complete comment, when seen going forwards, could be /*/**/. Likewise, if you see "" (going backwards), you have to check at least one more character because it could be \"". This is probably just the tip of the iceberg, but plenty to conclude the approach is doomed.

One might imagine that clang::Lexer provides access to token information recorded during the original parse, but it does not. Instead, it re-analyzes the source text from scratch. It is possible to hook into the original lexical analysis, to a degree, by using PPCallbacks, but that's overkill for this situation.

Getting the if condition

Fortunately, we don't have to go backward because we can go forward. In the case of an if statement, Clang provides getLParenLoc() and getRParenLoc(), which is all we need. Simply put the instrumentation immediately inside each of those.

Getting the for condition

The for condition is much harder. We want to know where the two semicolons are, but the Clang AST does not store their locations directly. So we need to use a more complicated procedure.

To get the first semicolon:

  • Get the start of the last token of the initializer.

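    • If the initializer is present, forStmt->getInit()->getEndLoc() yields the first character of its last token. Exception: if the initializer is a declaration such as int i = 0, it is a DeclStmt whose range already ends at the semicolon. In that case that location is the first semicolon itself, and no forward step is needed.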

    • Otherwise, use the left paren of the for, which is at forStmt->getLParenLoc().

  • Walk forward one token to the semicolon by using clang::Lexer::findNextToken. Generally, this method accepts the first character of one token and provides details about the token after that one, so you can iterate by hopping from first character to first character.

    • If macro shenanigans are in play, this might not be a semicolon. There is no general solution to that because macros can be arbitrarily complicated, but at least Token::is can be used to check (the code in the question already does so).

To get the second semicolon:

  • Get the first character of the last token of the condition as forStmt->getCond()->getEndLoc().

    • If there is no condition, use the first semicolon instead.
  • Walk forward to the second semicolon using findNextToken.

Now, place the instrumentation just inside this pair of semicolons.

Complete demonstration program

The question contains a complete program (thank you!), which I've modified. All of the added code is guarded by #ifdef CHANGED. It uses the procedure described above to find the conditions of if and for statements; the getConditionSourceRange method is where all of the action is.

It then inserts /*BEGIN*/ and /*END*/ comments surrounding those conditions using logic that's basically the same as in the original.

// SYSTEM INCLUDES
#include <iostream>
#include <filesystem>
#include <clang/AST/ASTContext.h>
#include <clang/AST/RecursiveASTVisitor.h>
#include <clang/Basic/Diagnostic.h>
#include <clang/Basic/DiagnosticOptions.h>
#include <clang/Basic/SourceLocation.h>
#include <clang/Basic/SourceManager.h>
#include <clang/Frontend/ASTUnit.h>
#include <clang/Frontend/CompilerInstance.h>
#include <clang/Rewrite/Core/Rewriter.h>
#include <clang/Serialization/PCHContainerOperations.h>
using namespace llvm;

namespace {
    // Option category for this tool's command-line options.
    // NOTE(review): nothing in this program parses llvm::cl options, so the
    // category is registered but appears to be unused.
    cl::OptionCategory gToolCategory{"Tool Category"};

    // Running count of /*BEGIN*/.../*END*/ probes inserted by the visitor.
    unsigned gProbeIndex{0u};
}

namespace tooling {
    class CProbeVisitor : public clang::RecursiveASTVisitor<CProbeVisitor> {
    public:
        explicit CProbeVisitor(
            clang::ASTContext& rContext,
            clang::Rewriter& rRewriter)
            : mContext{rContext}
            , mRewriter{rRewriter}
        {}

        static bool shouldTraversePostOrder() {
            // Must return true to traverse the AST in post-order.
            return true;
        }

        bool TraverseFunctionDecl(clang::FunctionDecl *FD) {
            const auto& SM = mContext.getSourceManager();
            if (!SM.isInMainFile(SM.getExpansionLoc(FD->getLocation()))) {
                // Skip this FunctionDecl and continue traversal.
                return true;
            }

            // Breadcrumb to access function name in other VisitStmt callbacks.
            mCurrentFunction = FD;

            // Now traverse the function body - should not need to check
            // locations in the visited statement callbacks.
            const bool shouldContinue =
                RecursiveASTVisitor::TraverseFunctionDecl(FD);

            if (!shouldContinue) {
                // Stop traversal.
                return false;
            }

            // This is called when exiting the function.
            // You can put your function exit logic here.
            return true;  // Continue traversal.
        }

        //! Visitor callback for 'clang::IfStmt'.
        bool VisitIfStmt(const clang::IfStmt *IS) const {
            if (IS->getCond()) {
#ifdef CHANGED
                if (const auto& [SR, condString] = getConditionSourceFromStmt(IS);
                    !condString.empty()) {
                    // Replace the condition string.
                    const auto probeText =
                        std::string("/*BEGIN*/") +
                        condString +
                        "/*END*/";
                    gProbeIndex++;
                    mRewriter.ReplaceText(SR, probeText);
                }
#else // !CHANGED
                if (const auto& [SR, condString] = getSourceFromStmt(
                    IS->getCond()); !condString.empty()) {
                    const auto& SM = mContext.getSourceManager();
                    const auto& LO = mContext.getLangOpts();
                    // Replace the condition string.
                    const auto probeText = std::format(
                        "/*BEGIN*/{}/*END*/"
                        , condString
                        , gProbeIndex++);
                    mRewriter.ReplaceText(SR, probeText);
                }
#endif // CHANGED
            }
            // Returning true continues the traversal.
            return true;
        }

        //! Visitor callback for "clang::ForStmt".
        bool VisitForStmt(const clang::ForStmt *FS) const {
            if (FS->getCond()) {
#ifdef CHANGED
                if (const auto& [SR, condString] = getConditionSourceFromStmt(FS);
                    !condString.empty()) {
                    // Replace the condition string.
                    const auto probeText =
                        std::string("/*BEGIN*/") +
                        condString +
                        "/*END*/";
                    gProbeIndex++;
                    mRewriter.ReplaceText(SR, probeText);
                }
#else // !CHANGED
                if (const auto& [SR, condString] = getSourceFromStmt(
                    FS->getCond()); !condString.empty()) {
                    const auto& SM = mContext.getSourceManager();
                    const auto& LO = mContext.getLangOpts();
                    // Replace the condition string.
                    const auto probeText = std::format(
                        "/*BEGIN*/{}/*END*/"
                        , condString
                        , gProbeIndex++);
                    mRewriter.ReplaceText(SR, probeText);
                }
#endif // CHANGED
            }
            // Returning true continues the traversal.
            return true;
        }

        //! Visitor callback for "clang::TranslationUnitDecl".
        static bool VisitTranslationUnitDecl(const clang::TranslationUnitDecl *TU) {
            return true;
        }

    private:
#ifdef CHANGED
        //! Render a location as "file:line:col" for diagnostic output.
        std::string locStr(clang::SourceLocation loc) const
        {
          return loc.printToString(mContext.getSourceManager());
        }

        // Similar to `getSourceFromStmt`, below: returns the condition range of
        // an if/for statement together with its current (rewritten) text, or an
        // empty pair when no usable range could be determined.
        std::pair<clang::CharSourceRange, std::string> getConditionSourceFromStmt(
          const clang::Stmt* stmt) const
        {
            const auto& SM = mContext.getSourceManager();
            const auto SR = getConditionSourceRange(stmt);
            // getConditionSourceRange returns an empty (invalid) range when it
            // bails out; never hand such a range to the rewriter.
            if (SR.isValid() &&
                SM.isWrittenInSameFile(SR.getBegin(), SR.getEnd())) {
                // Get the text *after* taking prior rewrites into account.
                return { SR, mRewriter.getRewrittenText(SR) };
            }
            return {};
        }

        // Sort of like `getSourceRangeWithComments`, below, but it anchors on
        // tokens whose positions are known (the parens of an `if`, the
        // semicolons of a `for`) and only ever lexes forward, so comments inside
        // those delimiters end up inside the returned character range.
        //
        // Diagnostics go to std::cerr so they are not mixed into the rewritten
        // source that main() prints on std::cout.
        clang::CharSourceRange getConditionSourceRange(const clang::Stmt* stmt) const
        {
            const auto& SM = mContext.getSourceManager();
            const auto& LO = mContext.getLangOpts();

            if (auto ifStmt = dyn_cast<clang::IfStmt>(stmt)) {
                clang::SourceLocation lparenLoc = ifStmt->getLParenLoc();
                clang::SourceLocation rparenLoc = ifStmt->getRParenLoc();
                if (lparenLoc.isInvalid() || rparenLoc.isInvalid()) {
                    // Offsetting an invalid location would fabricate a
                    // "valid" but meaningless one, so bail out here.
                    std::cerr << "if without paren locations, bailing out\n";
                    return {};
                }

                // Move past the '(' itself.  The range is a character range
                // (IsTokenRange == false), so it ends just before the ')'.
                // NOTE(review): for C++17 `if (init; cond)` this range also
                // covers the init-statement.
                lparenLoc = lparenLoc.getLocWithOffset(+1);

                return {{lparenLoc, rparenLoc}, false};
            }

            else if (auto forStmt = dyn_cast<clang::ForStmt>(stmt)) {
                clang::SourceLocation initLastTokenLoc;
                clang::SourceLocation firstSemicolonLoc;
                clang::Stmt const *init = forStmt->getInit();
                if (init && isa<clang::DeclStmt>(init)) {
                    // A declaration initializer (`for (int i = 0; ...`) is a
                    // DeclStmt whose source range already ends *at* the first
                    // semicolon; stepping one more token forward would land on
                    // the first token of the condition.
                    initLastTokenLoc = init->getEndLoc();
                    firstSemicolonLoc = initLastTokenLoc;
                }
                else {
                    // Get the start of the last token in the initializer,
                    // or use the left paren if there is no initializer.
                    initLastTokenLoc = init ? init->getEndLoc()
                                            : forStmt->getLParenLoc();

                    // Get the next token, which should be the first
                    // semicolon.
                    std::optional<clang::Token> firstSemicolonToken =
                        clang::Lexer::findNextToken(initLastTokenLoc, SM, LO);
                    if (!firstSemicolonToken) {
                        std::cerr << "findNextToken failed, bailing out\n";
                        return {};
                    }
                    std::cerr << "first semi is a semi: "
                              << firstSemicolonToken->is(clang::tok::semi) << "\n";
                    firstSemicolonLoc = firstSemicolonToken->getLocation();
                }

                // Get the location immediately after the first semi.
                // Don't just assume it is at +1 in case some macro
                // shenanigans are afoot (we can't handle everything
                // nefarious, but some cases we can).
                clang::SourceLocation afterFirstSemicolonLoc =
                    clang::Lexer::getLocForEndOfToken(firstSemicolonLoc, 0, SM, LO);

                // Get the last token of the condition.
                clang::SourceLocation condLastTokenLoc;
                if (clang::Expr const *cond = forStmt->getCond()) {
                    condLastTokenLoc = cond->getEndLoc();
                }
                else {
                    // No condition.  We could walk forward from the
                    // first semicolon, but the caller of this function
                    // already filters out `for` statements with no
                    // condition, so it does not matter.
                    std::cerr << "no cond, bailing out\n";
                    return {};
                }

                // Get the next token, which should be the second semi.
                std::optional<clang::Token> secondSemicolonToken =
                    clang::Lexer::findNextToken(condLastTokenLoc, SM, LO);
                if (!secondSemicolonToken) {
                    std::cerr << "findNextToken failed, bailing out\n";
                    return {};
                }
                std::cerr << "second semi is a semi: "
                          << secondSemicolonToken->is(clang::tok::semi) << "\n";
                clang::SourceLocation secondSemicolonLoc =
                    secondSemicolonToken->getLocation();

                std::cerr << "forStmt:" <<
                    "\n  InitLastToken: " << locStr(initLastTokenLoc) <<
                    "\n  AfterFirstSemi: " << locStr(afterFirstSemicolonLoc) <<
                    "\n  CondLastToken: " << locStr(condLastTokenLoc) <<
                    "\n  SecondSemi: " << locStr(secondSemicolonLoc) <<
                    "\n";

                return {{afterFirstSemicolonLoc, secondSemicolonLoc}, false};
            }

            else {
                assert(!"unhandled stmt kind");
            }

            // Reached only for unhandled statement kinds when assertions are
            // disabled; the empty range is rejected by the caller.
            return {};
        }

#endif // CHANGED
        //! Helper method to get the stmt source (accounting for prior rewrites & expansion locations).
        std::pair<clang::CharSourceRange, std::string> getSourceFromStmt(const clang::Stmt* stmt) const {
            const auto& SM = mContext.getSourceManager();
            // for some reason this is start of condition (not leading comments up to the end of the trailing comments)
            const auto SR = getSourceRangeWithComments(stmt);
            if (SM.isWrittenInSameFile(SR.getBegin(), SR.getEnd())) {
                // Get the text *after* taking prior rewrites into account.
                return { SR, mRewriter.getRewrittenText(SR) };
            }
            return {};
        }

        auto getSourceRangeWithComments(
            const clang::Stmt* stmt) const -> clang::CharSourceRange {
            const auto& SM = mContext.getSourceManager();
            const auto& LO = mContext.getLangOpts();

            // Get the expansion SourceRange of the expression.
            const auto SR = SM.getExpansionRange(stmt->getSourceRange());
            auto beginLoc = SR.getBegin();
            auto endLoc = SR.getEnd();

            // Adjust the beginning location backwards from the SR.getBegin.
            beginLoc = clang::Lexer::GetBeginningOfToken(beginLoc, SM, LO);
            while (SM.isWrittenInSameFile(beginLoc, SR.getBegin())) {
                clang::Token token;
                if (!clang::Lexer::getRawToken(beginLoc.getLocWithOffset(-1), token, SM, LO) &&
                    token.is(clang::tok::comment)) {
                    beginLoc = token.getLocation();
                } else {
                    // failed to get the raw token.
                    break;
                }
            }

            // Adjust the end location to the end of the trailing comments
            endLoc = clang::Lexer::getLocForEndOfToken(endLoc, 0, SM, LO);
            while (SM.isWrittenInSameFile(SR.getEnd(), endLoc)) {
                clang::Token token;
                if (!clang::Lexer::getRawToken(endLoc, token, SM, LO) &&
                    token.is(clang::tok::comment)) {
                    endLoc = clang::Lexer::getLocForEndOfToken(token.getEndLoc(), 0, SM, LO);
                } else {
                    break;
                }
            }

            // Now beginLoc and endLoc include the leading and trailing comments
            return {{beginLoc, endLoc}, false};
        }

        clang::ASTContext& mContext;
        clang::Rewriter& mRewriter;
        const clang::FunctionDecl *mCurrentFunction = nullptr;
    };
}

// This is all boilerplate for a program using the Clang C++ API
// ("libTooling") but not using the "tooling" part specifically.
int main(int argc, char const **argv)
{
    // Copy the arguments into a vector of char pointers since that is
    // what 'createInvocationFromCommandLine' wants.
    std::vector<char const *> commandLine;
    {
        // Path to the 'clang' binary that I am behaving like.  This path is
        // used to compute the location of compiler headers like stddef.h.
        // The Makefile sets 'CLANG_LLVM_INSTALL_DIR' on the compilation
        // command line.
        //commandLine.push_back("C:/tools/llvm-project/build-host/debug/bin/clang");
        commandLine.push_back("C:/tools/llvm/bin/clang");

        for (int i = 1; i < argc; ++i) {
            commandLine.push_back(argv[i]);
        }
    }

    // Parse the command line options.
    const std::shared_ptr<clang::CompilerInvocation> compilerInvocation(
        clang::createInvocation(llvm::ArrayRef(commandLine)));
    if (!compilerInvocation) {
        // Command line parsing errors have already been printed.
        return 2;
    }

    // Boilerplate setup for 'LoadFromCompilerInvocationAction'.
    const auto pchContainerOps = std::make_shared<clang::PCHContainerOperations>();
    const clang::IntrusiveRefCntPtr<clang::DiagnosticsEngine> diagnosticsEngine(
        clang::CompilerInstance::createDiagnostics(
            new clang::DiagnosticOptions));

    // Run the Clang parser to produce an AST.
    const std::unique_ptr<clang::ASTUnit> ast(
        clang::ASTUnit::LoadFromCompilerInvocationAction(
            compilerInvocation,
            pchContainerOps,
            diagnosticsEngine));

    if (ast == nullptr || diagnosticsEngine->getNumErrors() > 0) {
        // Error messages have already been printed.
        return 2;
    }

    clang::ASTContext& astContext = ast->getASTContext();
    const auto& SM = astContext.getSourceManager();
    clang::Rewriter rewriter(astContext.getSourceManager(), astContext.getLangOpts());


    tooling::CProbeVisitor visitor(astContext, rewriter);
    visitor.TraverseDecl(astContext.getTranslationUnitDecl());

    const auto MainFileID = SM.getMainFileID();
    const auto MainFileRange = clang::SourceRange(
        SM.getLocForStartOfFile(MainFileID),
        SM.getLocForEndOfFile(MainFileID));

    const auto FinalSourceCode = rewriter.getRewrittenText(MainFileRange);

    // Print the final source code to the console
    std::cout << FinalSourceCode << '\n';

    return 0;
}

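Further generalizing to handle switch and while should be straightforward, since they follow the pattern of if and expose getLParenLoc() and getRParenLoc(). DoStmt only records getWhileLoc() and getRParenLoc(), so for do the left paren has to be found by stepping one token forward from the while keyword with findNextToken.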

(My changes happen to replace std::format with string concatenation because the toolchain I'm using at the moment doesn't have std::format but there's nothing wrong with using it.)

Sample output

When run on the following input (expanded from the one in the question):

int t;

#define TRUE 1
#define FALSE 0

#define SEMI ;

void foo() {
    // Leading Macro test
    if (TRUE == 1) {
    }

    // Leading Comment, trailing Macro
    if (/*COMMENT*/0 == FALSE) {
    }

    // trailing Comment after Macro
    if (0 == TRUE/*COMMENT*/) {
    }

    // Leading comment
    if (/*COMMENT*/2 < t) {
    }

    // Trailing comment
    if (t < 2 /*COMMENT*/) {
    }

    for (t=10 ; t<25; ++t) {}
    for ( ; t<25; ++t) {}
    for (t=10 ; t<25/*COMMENT*/; ++t) {}
    for (t=10 ;/*COMMENT*/ t<25; ++t) {}

    // Obnoxious...
    for (t=10 SEMI t<25 SEMI ++t) {}
}

it produces this output:

first semi is a semi: 1
second semi is a semi: 1
forStmt:
  InitLastToken: test.c:29:12
  AfterFirstSemi: test.c:29:16
  CondLastToken: test.c:29:19
  SecondSemi: test.c:29:21
first semi is a semi: 1
second semi is a semi: 1
forStmt:
  InitLastToken: test.c:30:9
  AfterFirstSemi: test.c:30:12
  CondLastToken: test.c:30:15
  SecondSemi: test.c:30:17
first semi is a semi: 1
second semi is a semi: 1
forStmt:
  InitLastToken: test.c:31:12
  AfterFirstSemi: test.c:31:16
  CondLastToken: test.c:31:19
  SecondSemi: test.c:31:32
first semi is a semi: 1
second semi is a semi: 1
forStmt:
  InitLastToken: test.c:32:12
  AfterFirstSemi: test.c:32:16
  CondLastToken: test.c:32:30
  SecondSemi: test.c:32:32
first semi is a semi: 0
second semi is a semi: 0
forStmt:
  InitLastToken: test.c:35:12
  AfterFirstSemi: test.c:35:19
  CondLastToken: test.c:35:22
  SecondSemi: test.c:35:25
int t;

#define TRUE 1
#define FALSE 0

#define SEMI ;

void foo() {
    // Leading Macro test
    if (/*BEGIN*/TRUE == 1/*END*/) {
    }

    // Leading Comment, trailing Macro
    if (/*BEGIN*//*COMMENT*/0 == FALSE/*END*/) {
    }

    // trailing Comment after Macro
    if (/*BEGIN*/0 == TRUE/*COMMENT*//*END*/) {
    }

    // Leading comment
    if (/*BEGIN*//*COMMENT*/2 < t/*END*/) {
    }

    // Trailing comment
    if (/*BEGIN*/t < 2 /*COMMENT*//*END*/) {
    }

    for (t=10 ;/*BEGIN*/ t<25/*END*/; ++t) {}
    for ( ;/*BEGIN*/ t<25/*END*/; ++t) {}
    for (t=10 ;/*BEGIN*/ t<25/*COMMENT*//*END*/; ++t) {}
    for (t=10 ;/*BEGIN*//*COMMENT*/ t<25/*END*/; ++t) {}

    // Obnoxious...
    for (t=10 SEMI/*BEGIN*/ t<25 /*END*/SEMI ++t) {}
}

Upvotes: 1

Related Questions