|
26 | 26 | from superset.exceptions import SupersetSecurityException |
27 | 27 | from superset.extensions import appbuilder |
28 | 28 | from superset.models.slice import Slice |
29 | | -from superset.security.manager import SupersetSecurityManager |
| 29 | +from superset.security.manager import query_context_modified, SupersetSecurityManager |
30 | 30 | from superset.sql_parse import Table |
31 | 31 | from superset.superset_typing import AdhocMetric |
32 | 32 | from superset.utils.core import override_user |
@@ -414,3 +414,120 @@ def test_raise_for_access_chart_owner( |
414 | 414 | sm.raise_for_access( |
415 | 415 | chart=slice, |
416 | 416 | ) |
| 417 | + |
| 418 | + |
| 419 | +def test_query_context_modified( |
| 420 | + mocker: MockFixture, |
| 421 | + stored_metrics: list[AdhocMetric], |
| 422 | +) -> None: |
| 423 | + """ |
| 424 | + Test the `query_context_modified` function. |
| 425 | +
|
| 426 | + The function is used to ensure guest users are not modifying the request payload on |
| 427 | + embedded dashboard, preventing users from modifying it to access metrics different |
| 428 | + from the ones stored in dashboard charts. |
| 429 | + """ |
| 430 | + query_context = mocker.MagicMock() |
| 431 | + query_context.slice_.id = 42 |
| 432 | + query_context.slice_.query_context = None |
| 433 | + query_context.slice_.params_dict = { |
| 434 | + "metrics": stored_metrics, |
| 435 | + } |
| 436 | + |
| 437 | + query_context.form_data = { |
| 438 | + "slice_id": 42, |
| 439 | + "metrics": stored_metrics, |
| 440 | + } |
| 441 | + query_context.queries = [QueryObject(metrics=stored_metrics)] # type: ignore |
| 442 | + assert not query_context_modified(query_context) |
| 443 | + |
| 444 | + |
| 445 | +def test_query_context_modified_tampered( |
| 446 | + mocker: MockFixture, |
| 447 | + stored_metrics: list[AdhocMetric], |
| 448 | +) -> None: |
| 449 | + """ |
| 450 | + Test the `query_context_modified` function when the request is tampered with. |
| 451 | +
|
| 452 | + The function is used to ensure guest users are not modifying the request payload on |
| 453 | + embedded dashboard, preventing users from modifying it to access metrics different |
| 454 | + from the ones stored in dashboard charts. |
| 455 | + """ |
| 456 | + query_context = mocker.MagicMock() |
| 457 | + query_context.slice_.id = 42 |
| 458 | + query_context.slice_.query_context = None |
| 459 | + query_context.slice_.params_dict = { |
| 460 | + "metrics": stored_metrics, |
| 461 | + } |
| 462 | + |
| 463 | + tampered_metrics = [ |
| 464 | + { |
| 465 | + "column": None, |
| 466 | + "expressionType": "SQL", |
| 467 | + "hasCustomLabel": False, |
| 468 | + "label": "COUNT(*) + 2", |
| 469 | + "sqlExpression": "COUNT(*) + 2", |
| 470 | + } |
| 471 | + ] |
| 472 | + |
| 473 | + query_context.form_data = { |
| 474 | + "slice_id": 42, |
| 475 | + "metrics": tampered_metrics, |
| 476 | + } |
| 477 | + query_context.queries = [QueryObject(metrics=tampered_metrics)] # type: ignore |
| 478 | + assert query_context_modified(query_context) |
| 479 | + |
| 480 | + |
| 481 | +def test_query_context_modified_native_filter(mocker: MockFixture) -> None: |
| 482 | + """ |
| 483 | + Test the `query_context_modified` function with a native filter request. |
| 484 | +
|
| 485 | + A native filter request has no chart (slice) associated with it. |
| 486 | + """ |
| 487 | + query_context = mocker.MagicMock() |
| 488 | + query_context.slice_ = None |
| 489 | + |
| 490 | + assert not query_context_modified(query_context) |
| 491 | + |
| 492 | + |
| 493 | +def test_query_context_modified_mixed_chart(mocker: MockFixture) -> None: |
| 494 | + """ |
| 495 | + Test the `query_context_modified` function for a mixed chart request. |
| 496 | +
|
| 497 | + The metrics in the mixed chart are a nested dictionary (due to `columns`), and need |
| 498 | + to be serialized to JSON with the keys sorted in order to compare the request |
| 499 | + metrics with the chart metrics. |
| 500 | + """ |
| 501 | + stored_metrics = [ |
| 502 | + { |
| 503 | + "optionName": "metric_vgops097wej_g8uff99zhk7", |
| 504 | + "label": "AVG(num)", |
| 505 | + "expressionType": "SIMPLE", |
| 506 | + "column": {"column_name": "num", "type": "BIGINT(20)"}, |
| 507 | + "aggregate": "AVG", |
| 508 | + } |
| 509 | + ] |
| 510 | + # different order (remember, dicts have order!) |
| 511 | + requested_metrics = [ |
| 512 | + { |
| 513 | + "aggregate": "AVG", |
| 514 | + "column": {"column_name": "num", "type": "BIGINT(20)"}, |
| 515 | + "expressionType": "SIMPLE", |
| 516 | + "label": "AVG(num)", |
| 517 | + "optionName": "metric_vgops097wej_g8uff99zhk7", |
| 518 | + } |
| 519 | + ] |
| 520 | + |
| 521 | + query_context = mocker.MagicMock() |
| 522 | + query_context.slice_.id = 42 |
| 523 | + query_context.slice_.query_context = None |
| 524 | + query_context.slice_.params_dict = { |
| 525 | + "metrics": stored_metrics, |
| 526 | + } |
| 527 | + |
| 528 | + query_context.form_data = { |
| 529 | + "slice_id": 42, |
| 530 | + "metrics": requested_metrics, |
| 531 | + } |
| 532 | + query_context.queries = [QueryObject(metrics=requested_metrics)] # type: ignore |
| 533 | + assert not query_context_modified(query_context) |
0 commit comments