From 13655183767c90c3c4b0dc90fec46342ea6ed2fb Mon Sep 17 00:00:00 2001 From: Pimyn Girgis Date: Fri, 27 Sep 2024 09:24:27 +0000 Subject: [PATCH] tools/syz-declextract: code restructuring for better types --- tools/syz-declextract/syz-declextract.cpp | 276 ++++++++++++---------- 1 file changed, 157 insertions(+), 119 deletions(-) diff --git a/tools/syz-declextract/syz-declextract.cpp b/tools/syz-declextract/syz-declextract.cpp index 8d6132f04e6c..0000efe8ad91 100644 --- a/tools/syz-declextract/syz-declextract.cpp +++ b/tools/syz-declextract/syz-declextract.cpp @@ -34,6 +34,8 @@ #include #include #include +#include +#include #include using namespace clang; @@ -106,31 +108,65 @@ bool beginsWith(const std::string_view &str, const std::string_view begin) { bool contains(const std::string_view &str, const std::string_view sub) { return str.find(sub) != std::string::npos; } -bool isPadding(const std::string &fieldName) { - return contains(fieldName, "pad") || contains(fieldName, "unused") || contains(fieldName, "_reserved"); +const std::string int8Subtype(std::string name) { + std::transform(name.begin(), name.end(), name.begin(), ::tolower); + if (endsWith(name, "enabled") || endsWith(name, "enable")) { + return "bool8"; + } + return "int8"; } - -const std::string makeArray(const std::string &type, const size_t len = 0) { - if (len == 1) { - return type; +const std::string int16Subtype(std::string name) { + std::transform(name.begin(), name.end(), name.begin(), ::tolower); + if (contains(name, "port")) { + return "sock_port"; } - if (len) { - return "array[" + type + ", " + std::to_string(len) + "]"; + return "int16"; +} +const std::string int32Subtype(std::string name) { + std::transform(name.begin(), name.end(), name.begin(), ::tolower); + if (contains(name, "ipv4")) { + return "ipv4_addr"; } - return "array[" + type + "]"; + if (endsWith(name, "dfd") && !endsWith(name, "oldfd")) { + return "fd_dir"; + } + if (endsWith(name, "fd")) { + if (endsWith(name, "ns_fd")) { + return "fd_namespace"; + } + return "fd"; + } + if (contains(name, "ifindex")) { + return "ifindex"; + } + if (endsWith(name, "enabled") || endsWith(name, "enable")) { + return "bool32"; + } + return "int32"; } +const std::string stringSubtype(std::string name) { + std::transform(name.begin(), name.end(), name.begin(), ::tolower); + if (contains(name, "ifname") || endsWith(name, "dev_name")) { + return "devname"; + } + return "string"; +} +const std::string int64Subtype() { return "int64"; } -const std::string makeArray(const std::string &type, const size_t min, const size_t max) { - if (min == max) { - return makeArray(type, min); +const std::string makeArray(const std::string &type, const size_t min = 0, const size_t max = -1) { + if (max != size_t(-1)) { + return "array[" + type + ", " + std::to_string(min) + ":" + std::to_string(max) + "]"; + } + if (min == 1) { + return type; } - return "array[" + type + ", " + std::to_string(min) + ":" + std::to_string(max) + "]"; + if (min) { + return "array[" + type + ", " + std::to_string(min) + "]"; + } + return "array[" + type + "]"; } const std::string makePtr(const std::string &dir, const std::string &type, bool isOpt = false) { - if (type == "void") { - return makePtr(dir, makeArray("int8")); - } std::string ptr = "ptr[" + dir + ", " + type; if (isOpt) { return ptr + ", opt]"; @@ -142,32 +178,78 @@ const std::string makeConst(const std::string &type, const std::string &val = "0 if (type.empty()) { return "const[" + val + "]"; } - if (type == "void") { - return "const[" + val + ", intptr]"; - } return "const[" + val + ", " + type + "]"; } -const std::string getSyzType(const std::string &ctype, const bool isSyscallParam = false) { +enum IntType { + INVALID_INT = 0, + INT_8 = 1, + INT_16 = 2, + INT_32 = 4, + INT_64 = 8, + INT_PTR, +}; + +IntType getIntType(const std::string &ctype, const bool isSyscallParam) { + // TODO: Handle arm32 passing 64bit arguments if (!isSyscallParam && (contains(ctype, "long long") || contains(ctype, "64"))) { - return contains(ctype, "be") ? "int64be" : "int64"; + return INT_64; } if (contains(ctype, "16") || contains(ctype, "short")) { - return contains(ctype, "be") ? "int16be" : "int16"; + return INT_16; } - if (contains(ctype, "8")) { - return contains(ctype, "be") ? "int8be" : "int8"; + if (contains(ctype, "8") || contains(ctype, "char")) { + return INT_8; } if (contains(ctype, "32") || contains(ctype, "int")) { - return contains(ctype, "be") ? "int32be" : "int32"; + return INT_32; } - if (contains(ctype, "void")) { - return "void"; + if (contains(ctype, "void") || contains(ctype, "long")) { + return INT_PTR; } - if (contains(ctype, "char")) { - return "int8"; + fprintf(stderr, "Unhandled int length for type: %s\n", ctype.c_str()); + exit(1); +} + +const std::string intSubtype(const std::string &name, const IntType len) { + switch (len) { + case INT_8: + return int8Subtype(name); + case INT_16: + return int16Subtype(name); + case INT_32: + return int32Subtype(name); + case INT_64: + return int64Subtype(); + case INT_PTR: + return "intptr"; + case INVALID_INT: + fprintf(stderr, "Invalid int type\n"); + exit(1); } - return "intptr"; +} + +const std::string getSyzType(const std::string &ctype, const std::string &name, const bool isSyscallParam, + const int bitFieldWidth = 0) { + IntType len = getIntType(ctype, isSyscallParam); + const int byteLen = len * 8; + if (INT_8 <= len && len <= INT_64 && contains(ctype, "be")) { + return "int" + std::to_string(byteLen) + "be"; + } + + std::string type; + if (bitFieldWidth) { + type = "int" + std::to_string(byteLen); + if (byteLen != bitFieldWidth) { + type += ":" + std::to_string(bitFieldWidth); + } + } else { + type = intSubtype(name, len); + } + if (name.empty() || contains(name, "pad") || contains(name, "unused") || contains(name, "_reserved")) { + return makeConst(isSyscallParam ? "" : type); + } + return type; } class RecordExtractor { @@ -175,8 +257,9 @@ class RecordExtractor { const SourceManager *const SM; std::vector includes; std::vector flags; - std::unordered_map visitedRecord; // Maps name to index in extractedRecords - std::vector extractedRecords; + std::unordered_map extractedRecords; + const std::string emptyStructType = "empty struct"; + const std::string autoTodo = "auto_todo"; unsigned int getCountedBy(const FieldDecl *const &field) { return field->getType()->isCountAttributedType() @@ -195,7 +278,7 @@ class RecordExtractor { if (recordDecl->isStruct()) { for (const auto &item : recordDecl->getAttrs()) { if (item->getKind() == clang::attr::Aligned) { - return "align[" + std::to_string(llvm::dyn_cast(item)->getAlignment(*context)) + "]"; + return "align[" + std::to_string(llvm::dyn_cast(item)->getAlignment(*context) / 8) + "]"; } else if (item->getKind() == clang::attr::Packed) { return "packed"; } @@ -208,16 +291,10 @@ class RecordExtractor { RecordExtractor(const SourceManager *const SM) : SM(SM){}; std::string getFieldType(const QualType &fieldType, ASTContext *context, const std::string &fieldName, const std::string &parent = "", bool isSyscallParam = false) { - if (contains(fieldName, "ifindex")) { - return "ifindex"; - } - const auto &field = fieldType.IgnoreParens(); - switch (field->getTypeClass()) { + const auto &field = fieldType.IgnoreParens().getDesugaredType(*context); + switch (fieldType.IgnoreParens()->getTypeClass()) { case clang::Type::Record: return extractRecord(field->getAsRecordDecl(), context, parent.empty() ? fieldName : parent + "_" + fieldName); - case clang::Type::Typedef: - return getFieldType(llvm::dyn_cast(field)->desugar(), context, fieldType.getAsString(), parent, - isSyscallParam); case clang::Type::IncompleteArray: // Defined as type[] return makeArray(getFieldType(llvm::dyn_cast(field)->getElementType(), context, fieldName)); case clang::Type::ConstantArray: { @@ -227,22 +304,27 @@ class RecordExtractor { case clang::Type::Pointer: { const auto &pointerType = llvm::dyn_cast(field); const auto &pointeeType = pointerType->getPointeeType(); - const auto &fieldType = - pointeeType->isAnyCharacterType() ? "string" : getFieldType(pointeeType, context, fieldName); + std::string fieldType; + if (pointeeType->isAnyCharacterType()) { + fieldType = stringSubtype(fieldName); + } else if (pointeeType->isVoidType()) { + fieldType = makeArray("int8"); + } else { + fieldType = getFieldType(pointeeType, context, fieldName); + } const auto &ptrDir = pointeeType.isConstQualified() ? "in" : "inout"; // TODO: Infer direction of non-const. return makePtr(ptrDir, fieldType, parent + "$auto_record" == fieldType); // Checks if the direct parent is the same as the node. } - case clang::Type::Builtin: { - const auto &type = getSyzType(field.getAsString(), isSyscallParam); - return isPadding(fieldName) ? makeConst(isSyscallParam ? "" : type) : type; - } - case clang::Type::CountAttributed: // Has the attribute counted_by. Handled by getCountedBy - return getFieldType(llvm::dyn_cast(field)->desugar(), context, fieldName); + case clang::Type::Builtin: + return getSyzType(field.getAsString(), fieldName, isSyscallParam); + case clang::Type::CountAttributed: // Has the attribute counted_by. Handled by getCountedBy case clang::Type::BTFTagAttributed: // Currently Unused - return getFieldType(llvm::dyn_cast(field)->desugar(), context, fieldName); + case clang::Type::Typedef: + return getFieldType(field, context, field.getAsString(), parent, isSyscallParam); case clang::Type::Elaborated: - return getFieldType(llvm::dyn_cast(field)->desugar(), context, fieldName, parent, isSyscallParam); + return getFieldType(llvm::dyn_cast(fieldType)->desugar(), context, fieldName, parent, + isSyscallParam); // NOTE: The fieldType contains information we need, don't use field instead. case clang::Type::Enum: { const auto &enumDecl = llvm::dyn_cast(field)->getDecl(); auto name = enumDecl->getNameAsString(); @@ -256,7 +338,7 @@ class RecordExtractor { return "flags[" + name + "]"; } case clang::Type::FunctionProto: - return makePtr("in", "int8"); + return makePtr("in", autoTodo); default: field->dump(); fprintf(stderr, "Unhandled field type %s\n", field->getTypeClassName()); @@ -264,18 +346,17 @@ class RecordExtractor { } } - std::string extractRecord(const RecordDecl *recordDecl, ASTContext *context, const std::string &backupName = "") { + std::string extractRecord(const RecordDecl *recordDecl, ASTContext *context, const std::string &backupName) { recordDecl = recordDecl->getDefinition(); if (!recordDecl) { // When the definition is in a different translation unit. - return "intptr"; + return autoTodo; } const auto &name = (recordDecl->getNameAsString().empty() ? backupName : recordDecl->getNameAsString()); const auto &recordName = name + "$auto_record"; - if (visitedRecord.find(name) != visitedRecord.end()) { // Don't extract the same record twice. + if (extractedRecords.find(name) != extractedRecords.end()) { // Don't extract the same record twice. return recordName; } - visitedRecord[name] = extractedRecords.size(); - extractedRecords.resize(extractedRecords.size() + 1); + extractedRecords[name]; bool isVarlen = false; std::vector members; for (const auto &field : recordDecl->fields()) { @@ -288,35 +369,35 @@ class RecordExtractor { fieldName = field->getNameAsString(); } const std::string &parentName = field->isAnonymousStructOrUnion() ? "" : name; - std::string fieldType = getFieldType(field->getType(), context, fieldName, parentName); - if (field->isUnnamedBitField()) { // pad - fieldType = makeConst(fieldType); - } else if (field->isBitField()) { - fieldType = fieldType + ":" + std::to_string(field->getBitWidthValue(*context)); - } - if (fieldType.empty()) { + const std::string &fieldType = + field->isBitField() ? getSyzType(field->getType().getAsString(), field->isUnnamedBitField() ? "" : fieldName, + false, field->getBitWidthValue(*context)) + : getFieldType(field->getType(), context, fieldName, parentName); + if (fieldType == emptyStructType) { continue; } - isVarlen |= isFieldVarlen(field->getType()) || (visitedRecord.find(fieldName) != visitedRecord.end() && - extractedRecords[visitedRecord[fieldName]].isVarlen); + isVarlen |= isFieldVarlen(field->getType()) || + (extractedRecords.find(fieldName) != extractedRecords.end() && + !extractedRecords[fieldName].name.empty() && extractedRecords[fieldName].isVarlen); members.push_back({fieldType, fieldName, getCountedBy(field)}); } if (members.empty()) { // Empty structs are not allowed in Syzlang. - return ""; + return emptyStructType; } - extractedRecords[visitedRecord[name]] = {recordName, std::move(members), getStructAttr(recordDecl, context), - recordDecl->isUnion(), isVarlen}; + extractedRecords[name] = {recordName, std::move(members), getStructAttr(recordDecl, context), recordDecl->isUnion(), + isVarlen}; return recordName; } void print() { + puts("type auto_todo intptr"); for (const auto &inc : includes) { printf("include<%s>\n", inc.c_str()); } for (const auto &flag : flags) { puts(flag.c_str()); } - for (auto &decl : extractedRecords) { + for (auto &[_, decl] : extractedRecords) { for (auto &member : decl.members) { if (member.countedBy != UINT_MAX) { auto &type = decl.members[member.countedBy].type; @@ -324,7 +405,7 @@ class RecordExtractor { } } } - for (const auto &decl : extractedRecords) { + for (const auto &[_, decl] : extractedRecords) { decl.print(); } } @@ -396,46 +477,6 @@ class NetlinkPolicyMatcher : public MatchFinder::MatchCallback { } } - const std::string nlaInt8Subtype(const std::string &name) { - if (endsWith(name, "ENABLED") || endsWith(name, "ENABLE")) { - return "bool8"; - } - return "int8"; - } - const std::string nlaInt16Subtype(const std::string &name) { - if (contains(name, "PORT")) { - return "sock_port"; - } - return "int16"; - } - const std::string nlaInt32Subtype(const std::string &name) { - if (contains(name, "IPV4")) { - return "ipv4_addr"; - } - if (endsWith(name, "FD")) { - if (endsWith(name, "NS_FD")) { - return "fd_namespace"; - } - return "fd"; - } - if (contains(name, "IFINDEX")) { - return "ifindex"; - } - if (endsWith(name, "ENABLED") || endsWith(name, "ENABLE")) { - return "bool32"; - } - return "int32"; - } - - const std::string nlaStringSubtype(const std::string &name) { - if (contains(name, "IFNAME") || endsWith(name, "DEV_NAME")) { - return "devname"; - } - return "string"; - } - - const std::string nlaInt64Subtype() { return "int64"; } - const std::string nlaArraySubtype(const std::string &name, const std::string &type, const size_t len, const std::string &typeOfLen) { if (!typeOfLen.empty()) { @@ -445,13 +486,10 @@ class NetlinkPolicyMatcher : public MatchFinder::MatchCallback { case 0: return makeArray("int8"); case 1: - return nlaInt8Subtype(name); case 2: - return nlaInt16Subtype(name); case 4: - return nlaInt32Subtype(name); case 8: - return nlaInt64Subtype(); + return intSubtype(name, IntType(len)); default: if (contains(name, "IPV6")) { return "ipv6_addr"; @@ -468,10 +506,10 @@ class NetlinkPolicyMatcher : public MatchFinder::MatchCallback { // TODO:Gather information from other defined fields to better specify a type. // Loosely based on https://elixir.bootlin.com/linux/v6.10/source/lib/nlattr.c if (type == "U8" || type == "S8") { - return nlaInt8Subtype(name); + return int8Subtype(name); } if (type == "U16" || type == "S16") { - return nlaInt16Subtype(name); + return int16Subtype(name); } if (type == "BINARY") { return nlaArraySubtype(name, type, len, typeOfLen); @@ -480,7 +518,7 @@ class NetlinkPolicyMatcher : public MatchFinder::MatchCallback { return "int16be"; } if (type == "U32" || type == "S32") { - return nlaInt32Subtype(name); + return int32Subtype(name); } if (type == "BE32") { return "int32be"; @@ -495,7 +533,7 @@ class NetlinkPolicyMatcher : public MatchFinder::MatchCallback { return "stringnoz"; } if (type == "NUL_STRING") { - return nlaStringSubtype(name); + return stringSubtype(name); } if (type == "BITFIELD32") { // TODO:Extract valued values from NLA_POLICY_BITFIELD32 macro. return "int32";