diff --git a/docs/benefits-over-pyright/fixed-context-manager-exit-types.md b/docs/benefits-over-pyright/fixed-context-manager-exit-types.md
new file mode 100644
index 000000000..212e9cb23
--- /dev/null
+++ b/docs/benefits-over-pyright/fixed-context-manager-exit-types.md
@@ -0,0 +1,140 @@
+# fixed handling for context managers that can suppress exceptions
+
+## the problem
+
+if an exception is raised inside a context manager and its `__exit__` method returns `True`, it will be suppressed:
+
+```py
+class SuppressError(AbstractContextManager[None, bool]):
+ @override
+ def __enter__(self) -> None:
+ pass
+
+ @override
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ /,
+ ) -> bool:
+ return True
+```
+
+but if it returns `False` or `None`, the exception will not be suppressed:
+
+```py
+class Log(AbstractContextManager[None, Literal[False]]):
+ @override
+ def __enter__(self) -> None:
+ print("entering context manager")
+
+ @override
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ /,
+ ) -> Literal[False]:
+ print("exiting context manager")
+ return False
+```
+
+pyright will take this into account when determining reachability:
+
+```py
+def raise_exception() -> Never:
+ raise Exception
+
+with SuppressError():
+ foo: int = raise_exception()
+
+# when the exception is raised, the context manager exits before foo is assigned to:
+print(foo) # error: "foo" is unbound (reportPossiblyUnboundVariable)
+```
+
+```py
+with Log():
+ foo: int = raise_exception()
+
+# when the exception is raised, it does not get suppressed so this line can never run:
+print(foo) # error: Code is unreachable (reportUnreachable)
+```
+
+however, due to [a bug in mypy](https://github.com/python/mypy/issues/8766) that [pyright blindly copied and accepted as the "standard"](https://github.com/microsoft/pyright/issues/6034#issuecomment-1738941412), a context manager will incorrectly be treated as if it never suppresses exceptions if its return type is a union of `bool | None`:
+
+```py
+class SuppressError(AbstractContextManager[None, bool | None]):
+ @override
+ def __enter__(self) -> None:
+ pass
+
+ @override
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ /,
+ ) -> bool | None:
+ return True
+
+
+with SuppressError():
+ foo: int = raise_exception()
+
+# this error is wrong because this line is actually reached at runtime:
+print(foo) # error: Code is unreachable (reportUnreachable)
+```
+
+## the solution
+
+basedpyright introduces a new setting, `strictContextManagerExitTypes` to address this issue. when enabled, context managers where the `__exit__` dunder returns `bool | None` are treated the same way as context managers that return `bool` or `Literal[True]`. put simply, if `True` is assignable to the return type, then it's treated as if it can suppress exceptions.
+
+## issues with `@contextmanager`
+
+the reason we support disabling this fix using the `strictContextManagerExitTypes` setting is because it will cause all context managers decorated with `@contextlib.contextmanager` to be treated as if they can suppress an exception, even if they never do:
+
+```py
+@contextmanager
+def log():
+ print("entering context manager")
+ try:
+ yield
+ finally:
+ print("exiting context manager")
+
+with log():
+ foo: int = get_value()
+
+# basedpyright accounts for the possibility that get_value raised an exception and foo
+# was never assigned to, even though this context manager never suppresses exceptions
+print(foo) # error: "foo" is unbound (reportPossiblyUnboundVariable)
+```
+
+this is because there's no way to tell a type checker whether the function body contains a `try`/`except` statement, which is necessary to suppress exeptions when using the `@contextmanager` decorator:
+
+```py
+@contextmanager
+def suppress_error():
+ try:
+ yield
+ except:
+ pass
+```
+
+as a workaround, it's recommended to instead use class context managers [like in the examples above](#the-problem) for the following reasons:
+
+- it forces you to be explicit about whether or not the context manager is able to suppress an exception
+- it prevents you from accidentally creating a context manager that doesn't run its cleanup if an exception occurs:
+ ```py
+ @contextmanager
+ def suppress_error():
+ print("setup")
+ yield
+ # this part won't run if an exception is raised because you forgot to use a try/finally
+ print("cleanup")
+ ```
+
+if you're dealing with third party modules where the usage of `@contextmanager` decorator is unavoidable, it may be best to just disable `strictContextManagerExitTypes` instead.
diff --git a/docs/configuration.md b/docs/configuration.md
index d87a42bb9..71cc2ea9a 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -77,6 +77,8 @@ The following settings determine how different types should be evaluated.
- **strictGenericNarrowing** [boolean]: When a type is narrowed in such a way that its type parameters are not known (eg. using an `isinstance` check), basedpyright will resolve the type parameter to the generic's bound or constraint instead of `Any`. [more info](../benefits-over-pyright/improved-generic-narrowing.md)
+- **strictContextManagerExitTypes** [boolean]: Assume that a context manager could potentially suppress an exception if its `__exit__` method is typed as returning `bool | None`. [more info](../benefits-over-pyright/fixed-context-manager-exit-types.md)
+
## Diagnostic Categories
diagnostics can be configured to be reported as any of the following categories:
diff --git a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts
index 88fea551f..10d050016 100644
--- a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts
+++ b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts
@@ -1968,7 +1968,10 @@ export function getCodeFlowEngine(
// valid return types here are `bool | None`. if the context manager returns `True` then it suppresses,
// meaning we only know for sure that the context manager can't swallow exceptions if its return type
// does not allow `True`.
- const typesToCheck = isUnion(returnType) ? returnType.priv.subtypes : [returnType];
+ const typesToCheck =
+ getFileInfo(node).diagnosticRuleSet.strictContextManagerExitTypes && isUnion(returnType)
+ ? returnType.priv.subtypes
+ : [returnType];
const boolType = typesToCheck.find(
(type): type is ClassType => isClassInstance(type) && ClassType.isBuiltIn(type, 'bool')
);
diff --git a/packages/pyright-internal/src/common/configOptions.ts b/packages/pyright-internal/src/common/configOptions.ts
index fb7795579..1f4439e54 100644
--- a/packages/pyright-internal/src/common/configOptions.ts
+++ b/packages/pyright-internal/src/common/configOptions.ts
@@ -412,6 +412,7 @@ export interface DiagnosticRuleSet {
*/
failOnWarnings: boolean;
strictGenericNarrowing: boolean;
+ strictContextManagerExitTypes: boolean;
reportUnreachable: DiagnosticLevel;
reportAny: DiagnosticLevel;
reportExplicitAny: DiagnosticLevel;
@@ -443,6 +444,7 @@ export function getBooleanDiagnosticRules(includeNonOverridable = false) {
DiagnosticRule.deprecateTypingAliases,
DiagnosticRule.disableBytesTypePromotions,
DiagnosticRule.strictGenericNarrowing,
+ DiagnosticRule.strictContextManagerExitTypes,
];
if (includeNonOverridable) {
@@ -676,6 +678,7 @@ export function getOffDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
+ strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
@@ -792,6 +795,7 @@ export function getBasicDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
+ strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
@@ -908,6 +912,7 @@ export function getStandardDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
+ strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
@@ -1023,6 +1028,7 @@ export const getRecommendedDiagnosticRuleSet = (): DiagnosticRuleSet => ({
reportImplicitOverride: 'warning',
failOnWarnings: true,
strictGenericNarrowing: true,
+ strictContextManagerExitTypes: true,
reportUnreachable: 'warning',
reportAny: 'warning',
reportExplicitAny: 'warning',
@@ -1135,6 +1141,7 @@ export const getAllDiagnosticRuleSet = (): DiagnosticRuleSet => ({
reportImplicitOverride: 'error',
failOnWarnings: true,
strictGenericNarrowing: true,
+ strictContextManagerExitTypes: true,
reportUnreachable: 'error',
reportAny: 'error',
reportExplicitAny: 'error',
@@ -1248,6 +1255,7 @@ export function getStrictDiagnosticRuleSet(): DiagnosticRuleSet {
reportImplicitOverride: 'none',
failOnWarnings: false,
strictGenericNarrowing: false,
+ strictContextManagerExitTypes: false,
reportUnreachable: 'hint',
reportAny: 'none',
reportExplicitAny: 'none',
diff --git a/packages/pyright-internal/src/common/diagnosticRules.ts b/packages/pyright-internal/src/common/diagnosticRules.ts
index a6d564d52..9245a2981 100644
--- a/packages/pyright-internal/src/common/diagnosticRules.ts
+++ b/packages/pyright-internal/src/common/diagnosticRules.ts
@@ -107,6 +107,7 @@ export enum DiagnosticRule {
// basedpyright options:
failOnWarnings = 'failOnWarnings',
strictGenericNarrowing = 'strictGenericNarrowing',
+ strictContextManagerExitTypes = 'strictContextManagerExitTypes',
reportUnreachable = 'reportUnreachable',
reportAny = 'reportAny',
reportExplicitAny = 'reportExplicitAny',
diff --git a/packages/pyright-internal/src/tests/checker.test.ts b/packages/pyright-internal/src/tests/checker.test.ts
index 020cbbc92..04308efea 100644
--- a/packages/pyright-internal/src/tests/checker.test.ts
+++ b/packages/pyright-internal/src/tests/checker.test.ts
@@ -170,13 +170,30 @@ test('With2', () => {
TestUtils.validateResults(analysisResults, 3);
});
-test('context manager where __exit__ returns bool | None', () => {
- const analysisResults = TestUtils.typeAnalyzeSampleFiles(['withBased.py']);
- TestUtils.validateResultsButBased(analysisResults, {
- hints: [
- { code: DiagnosticRule.reportUnreachable, line: 45 },
- { code: DiagnosticRule.reportUnreachable, line: 60 },
- ],
+describe('context manager where __exit__ returns bool | None', () => {
+ test('strictContextManagerExitTypes=true', () => {
+ const configOptions = new ConfigOptions(Uri.empty());
+ configOptions.diagnosticRuleSet.strictContextManagerExitTypes = true;
+ const analysisResults = TestUtils.typeAnalyzeSampleFiles(['withBased.py'], configOptions);
+ TestUtils.validateResultsButBased(analysisResults, {
+ hints: [
+ { code: DiagnosticRule.reportUnreachable, line: 45 },
+ { code: DiagnosticRule.reportUnreachable, line: 60 },
+ ],
+ });
+ });
+ test('strictContextManagerExitTypes=false', () => {
+ const configOptions = new ConfigOptions(Uri.empty());
+ configOptions.diagnosticRuleSet.strictContextManagerExitTypes = false;
+ const analysisResults = TestUtils.typeAnalyzeSampleFiles(['withBased.py'], configOptions);
+ TestUtils.validateResultsButBased(analysisResults, {
+ hints: [
+ { code: DiagnosticRule.reportUnreachable, line: 16 },
+ { code: DiagnosticRule.reportUnreachable, line: 30 },
+ { code: DiagnosticRule.reportUnreachable, line: 45 },
+ { code: DiagnosticRule.reportUnreachable, line: 60 },
+ ],
+ });
});
});