summaryrefslogtreecommitdiffstats
path: root/gnu/llvm/clang/unittests/AST/ASTPrint.h
blob: c3b6b842316d9e085ec1191b5708e4e67eb219dc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
//===- unittests/AST/ASTPrint.h ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Helpers to simplify testing of printing of AST constructs provided in the
// form of the source code.
//
//===----------------------------------------------------------------------===//

#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/ADT/SmallString.h"
#include "gtest/gtest.h"

namespace clang {

// Optional hook that receives a mutable PrintingPolicy so a test can tweak
// it before a statement is printed. In this header an empty value (None)
// means "print with the ASTContext's policy unchanged": PrintStmt only
// invokes the hook when the Optional holds a value.
// NOTE(review): llvm::function_ref is documented as a non-owning reference to
// a callable; presumably the callable passed in must stay alive for as long
// as the adjuster may be invoked -- confirm against the LLVM docs.
using PolicyAdjusterType =
    Optional<llvm::function_ref<void(PrintingPolicy &Policy)>>;

/// Pretty-prints the statement \p S to \p Out.
///
/// The printing policy starts out as a local copy of the policy returned by
/// \p Context; when \p PolicyAdjuster holds a callback, that callback is
/// invoked on the copy before the statement is printed.
///
/// This function is `inline` rather than `static` on purpose: it lives in a
/// header and is called from PrintMatch::run, which is an implicitly inline
/// member function. With internal linkage every translation unit would get
/// its own distinct PrintStmt, so the (supposedly identical) definitions of
/// PrintMatch::run would each refer to a different entity, which the
/// one-definition rule does not allow. `inline` gives a single entity shared
/// by all translation units and also avoids unused-function warnings in
/// files that include this header without calling it.
///
/// \param Out            Stream that receives the printed text.
/// \param Context        AST context supplying the base printing policy; it
///                       is dereferenced unconditionally.
/// \param S              Statement to print; must not be null (asserted).
/// \param PolicyAdjuster Optional hook used to customize the policy copy.
inline void PrintStmt(raw_ostream &Out, const ASTContext *Context,
                      const Stmt *S, PolicyAdjusterType PolicyAdjuster) {
  assert(S != nullptr && "Expected non-null Stmt");
  // Work on a copy so the adjuster's changes stay local to this call.
  PrintingPolicy Policy = Context->getPrintingPolicy();
  if (PolicyAdjuster)
    (*PolicyAdjuster)(Policy);
  S->printPretty(Out, /*Helper*/ nullptr, Policy);
}

/// Match callback that records the textual form of a matched statement.
///
/// Every match result that binds a Stmt under the name "id" bumps a counter.
/// Only the match that brings the counter to one is actually printed (via
/// PrintStmt, using the stored policy adjuster); subsequent matches are just
/// counted so that the caller can detect ambiguous matchers.
class PrintMatch : public ast_matchers::MatchFinder::MatchCallback {
  /// Buffer holding the printed text of the first matched statement.
  SmallString<1024> Printed;
  /// Number of match results seen so far that carried a Stmt bound to "id".
  unsigned NumFoundStmts;
  /// Optional policy hook forwarded to PrintStmt.
  PolicyAdjusterType PolicyAdjuster;

public:
  /// Starts with a zero match count and remembers \p PolicyAdjuster for
  /// later use when printing.
  PrintMatch(PolicyAdjusterType PolicyAdjuster)
      : NumFoundStmts(0), PolicyAdjuster(PolicyAdjuster) {}

  /// Handles one match result: counts it if it binds a Stmt as "id", and
  /// prints that statement into the internal buffer if it is the first one.
  void run(const ast_matchers::MatchFinder::MatchResult &Result) override {
    if (const Stmt *Matched = Result.Nodes.getNodeAs<Stmt>("id")) {
      ++NumFoundStmts;
      // Anything beyond the first hit is only tallied, never printed.
      if (NumFoundStmts <= 1) {
        llvm::raw_svector_ostream Stream(Printed);
        PrintStmt(Stream, Result.Context, Matched, PolicyAdjuster);
      }
    }
  }

  /// Returns a view of the text printed so far (empty if nothing matched).
  StringRef getPrinted() const { return Printed.str(); }

  /// Returns how many "id"-bound statements have been reported.
  unsigned getNumFoundStmts() const { return NumFoundStmts; }
};

/// Parses \p Code with the extra compiler arguments \p Args, runs the matcher
/// \p NodeMatch over it, and checks that exactly one statement was matched
/// and that its printed form equals \p ExpectedPrinted.
///
/// The check fails (with a descriptive message) when the tool run reports
/// failure, when no statement is matched, when more than one statement is
/// matched, or when the printed text differs from the expectation.
///
/// \param Code            Source text to parse.
/// \param Args            Arguments passed on to the tool invocation.
/// \param NodeMatch       Matcher registered with the finder; PrintMatch
///                        looks for a Stmt bound under the name "id".
/// \param ExpectedPrinted Exact text the printed statement must equal.
/// \param PolicyAdjuster  Optional hook to customize the printing policy;
///                        defaults to None.
/// \returns AssertionSuccess on a single, correctly printed match; otherwise
///          an AssertionFailure describing what went wrong.
template <typename T>
::testing::AssertionResult
PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
                   const T &NodeMatch, StringRef ExpectedPrinted,
                   PolicyAdjusterType PolicyAdjuster = None) {
  // Wire the printing callback into a finder and wrap it in an action factory.
  PrintMatch Callback(PolicyAdjuster);
  ast_matchers::MatchFinder MatchDriver;
  MatchDriver.addMatcher(NodeMatch, &Callback);
  std::unique_ptr<tooling::FrontendActionFactory> ActionFactory(
      tooling::newFrontendActionFactory(&MatchDriver));

  const bool ToolSucceeded =
      tooling::runToolOnCodeWithArgs(ActionFactory->create(), Code, Args);
  if (!ToolSucceeded)
    return ::testing::AssertionFailure()
           << "Parsing error in \"" << Code.str() << "\"";

  // The matcher is expected to single out exactly one statement.
  const unsigned MatchCount = Callback.getNumFoundStmts();
  if (MatchCount == 0)
    return ::testing::AssertionFailure()
           << "Matcher didn't find any statements";
  if (MatchCount > 1)
    return ::testing::AssertionFailure()
           << "Matcher should match only one statement (found " << MatchCount
           << ")";

  const StringRef ActualPrinted = Callback.getPrinted();
  if (ActualPrinted == ExpectedPrinted)
    return ::testing::AssertionSuccess();

  return ::testing::AssertionFailure()
         << "Expected \"" << ExpectedPrinted.str() << "\", got \""
         << ActualPrinted.str() << "\"";
}

} // namespace clang