Skip to content

Airflow — 工作流调度

什么是 Airflow

Apache Airflow 是一个工作流编排平台,用于以编程方式创建、调度和监控数据管道。它将工作流定义为 DAG(有向无环图),每个节点是一个任务,边表示依赖关系。

核心优势

  • 代码即配置:用 Python 定义工作流
  • 丰富的 Operator:内置 Bash、Python 等 Operator,Spark、Hive、HTTP 等则由官方 Provider 包(airflow.providers.*)提供
  • 可视化监控:Web UI 查看 DAG 状态和日志
  • 动态 DAG:根据参数动态生成任务
  • 重试与告警:自动重试失败任务,发送告警

核心概念

| 概念 | 说明 |
| --- | --- |
| DAG | 有向无环图,定义工作流的任务和依赖关系 |
| Task | DAG 中的一个节点,代表一个工作单元 |
| Operator | Task 的类型(BashOperator、PythonOperator 等) |
| DAG Run | DAG 的一次执行实例 |
| Task Instance | Task 在某次 DAG Run 中的执行实例 |
| Scheduler | 定期扫描 DAG,触发到期的 DAG Run |
| Executor | 执行 Task 的组件(LocalExecutor、CeleryExecutor 等) |
| XCom | Task 间传递数据的机制 |

第一个 DAG

python
from datetime import datetime, timedelta
from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator

# DAG 默认参数
default_args = dict(
    # Shared task defaults, handed to the DAG via default_args=...
    owner='data-team',
    depends_on_past=False,
    start_date=datetime(2024, 1, 1),
    # Alerting: mail on failure, stay quiet on retries.
    email=['alert@example.com'],
    email_on_failure=True,
    email_on_retry=False,
    # Retry policy: up to 3 retries, 5 minutes apart.
    retries=3,
    retry_delay=timedelta(minutes=5),
)

# 定义 DAG
# Daily ETL DAG: extract (Sqoop) -> clean (Spark) -> aggregate (Hive) -> quality check.
with DAG(
    dag_id='daily_etl_pipeline',
    default_args=default_args,
    description='每日 ETL 数据管道',
    schedule_interval='0 2 * * *',  # cron: every day at 02:00
    catchup=False,                   # do not backfill historical runs
    tags=['etl', 'daily'],
) as dag:

    # Task 1: import the MySQL `orders` table into the day's ODS directory with Sqoop.
    # NOTE(review): the password is passed in plain text on the command line; for
    # real deployments prefer Sqoop's --password-file or an Airflow connection.
    extract = BashOperator(
        task_id='extract_from_mysql',
        bash_command="""
            sqoop import \
              --connect jdbc:mysql://mysql:3306/shop \
              --username root --password password \
              --table orders \
              --target-dir /data/ods/orders/dt={{ ds }} \
              --delete-target-dir \
              --num-mappers 4
        """,
    )

    # Task 2: clean the data with a Spark job submitted to YARN.
    def run_spark_job(**context):
        """Submit the Spark cleaning job for the run's logical date.

        ``context['ds']`` is the logical date string (e.g. ``2024-01-01``) and
        is passed to the job as its only application argument.

        Raises:
            RuntimeError: spark-submit exited with a non-zero return code.
        """
        import subprocess

        dt = context['ds']
        result = subprocess.run(
            [
                'spark-submit',
                '--master', 'yarn',
                '--class', 'com.example.CleanJob',
                '/opt/jars/etl.jar',
                dt,
            ],
            capture_output=True,
            text=True,
        )
        # Output is captured (for the error message below), so echo it explicitly;
        # otherwise the driver output never reaches the task log.
        for stream in (result.stdout, result.stderr):
            if stream:
                print(stream)
        if result.returncode != 0:
            raise RuntimeError(f"Spark job failed: {result.stderr}")

    # Airflow 2 passes the context to callables automatically; `provide_context`
    # is deprecated and no longer needed.
    clean = PythonOperator(
        task_id='clean_with_spark',
        python_callable=run_spark_job,
    )

    # Task 3: aggregate per-user order counts/amounts into the day's DWS partition.
    aggregate = BashOperator(
        task_id='aggregate_in_hive',
        bash_command="""
            hive -e "
                INSERT OVERWRITE TABLE dws_order_daily PARTITION (dt='{{ ds }}')
                SELECT user_id, COUNT(*) AS order_cnt, SUM(amount) AS total
                FROM dwd_orders
                WHERE dt = '{{ ds }}'
                GROUP BY user_id;
            "
        """,
    )

    # Task 4: data quality gate on the aggregated partition.
    def check_data_quality(min_rows=1000, **context):
        """Fail when the day's ``dws_order_daily`` partition has too few rows.

        Args:
            min_rows: minimum acceptable row count (default 1000); can be
                overridden through ``op_kwargs``.
            **context: Airflow task context; ``context['ds']`` selects the partition.

        Raises:
            ValueError: the partition holds fewer than ``min_rows`` rows.
        """
        from contextlib import closing

        from pyhive import hive

        dt = context['ds']
        # closing() guarantees the cursor and connection are released even if the query fails.
        with closing(hive.Connection(host='hive-server', port=10000)) as conn:
            with closing(conn.cursor()) as cursor:
                # Bind the date as a parameter instead of formatting it into the SQL text.
                cursor.execute(
                    "SELECT COUNT(*) FROM dws_order_daily WHERE dt = %(dt)s",
                    {'dt': dt},
                )
                count = cursor.fetchone()[0]
        if count < min_rows:
            raise ValueError(f"Data quality check failed: only {count} rows")

    quality_check = PythonOperator(
        task_id='data_quality_check',
        python_callable=check_data_quality,
    )

    # Dependencies: run the four tasks strictly in sequence.
    extract >> clean >> aggregate >> quality_check

常用 Operator

BashOperator

python
from airflow.operators.bash import BashOperator

task = BashOperator(
    task_id='run_script',
    bash_command='python /opt/scripts/process.py --date {{ ds }}',
    # On its own, `env` REPLACES the inherited environment (PATH etc. would be lost);
    # append_env=True (Airflow 2.3+) merges PYTHONPATH into it instead.
    env={'PYTHONPATH': '/opt/lib'},
    append_env=True,
)

PythonOperator

python
from airflow.operators.python import PythonOperator

def my_function(ds, **kwargs):
    """Print the run's logical date and report success.

    Any extra keyword arguments (such as entries from ``op_kwargs``) are
    accepted and ignored.
    """
    message = "Processing date: {}".format(ds)
    print(message)
    return "success"

# op_kwargs entries are forwarded to the callable as keyword arguments.
task = PythonOperator(
    python_callable=my_function,
    task_id='python_task',
    op_kwargs=dict(extra_param='value'),
)

SparkSubmitOperator

python
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

spark_task = SparkSubmitOperator(
    task_id='spark_job',
    application='/opt/jars/myapp.jar',
    conn_id='spark_default',
    java_class='com.example.MyJob',
    application_args=['--date', '{{ ds }}'],
    # Resources are set with the operator's dedicated arguments rather than a
    # duplicated `conf` dict. Note that 'spark.num.executors' is not a Spark
    # property (the real one is spark.executor.instances), so num_executors
    # is what actually sets the executor count.
    executor_memory='8g',
    executor_cores=4,
    num_executors=10,
    driver_memory='4g',
)

HiveOperator

python
from airflow.providers.apache.hive.operators.hive import HiveOperator

# Rebuild the day's partition of `result` with per-user row counts from `source_table`.
# NOTE(review): Airflow's bundled default Hive CLI connection is 'hive_cli_default';
# confirm a 'hive_default' connection exists in this deployment.
hive_task = HiveOperator(
    hive_cli_conn_id='hive_default',
    task_id='hive_query',
    hql="""
        INSERT OVERWRITE TABLE result PARTITION (dt='{{ ds }}')
        SELECT user_id, COUNT(*) AS cnt
        FROM source_table
        WHERE dt = '{{ ds }}'
        GROUP BY user_id;
    """,
)

HttpSensor(等待 HTTP 接口就绪)

python
from airflow.providers.http.sensors.http import HttpSensor

def _api_reports_ok(response):
    """Return True when the health endpoint answers with JSON ``{"status": "ok", ...}``.

    A body that is not JSON, is not a JSON object, or lacks ``status`` is
    treated as "not ready yet" (False), so the sensor keeps poking instead of
    failing on an exception.
    """
    try:
        payload = response.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        return False
    return isinstance(payload, dict) and payload.get('status') == 'ok'


wait_for_api = HttpSensor(
    task_id='wait_for_api',
    http_conn_id='my_api',
    endpoint='/health',
    response_check=_api_reports_ok,
    poke_interval=30,  # seconds between checks
    timeout=300,       # stop waiting after 5 minutes
)

XCom(任务间传递数据)

python
def push_data(**context):
    """Publish ``row_count`` to XCom under an explicit key.

    The returned string is also pushed automatically, under the
    ``return_value`` key.
    """
    task_instance = context['ti']
    task_instance.xcom_push(key='row_count', value=12345)
    return "pushed"

def pull_data(**context):
    """Fetch the ``row_count`` XCom pushed by ``push_task`` and print it."""
    row_count = context['ti'].xcom_pull(task_ids='push_task', key='row_count')
    print("Row count: {}".format(row_count))

# NOTE(review): assumes these lines sit inside a `with DAG(...)` block, as in the
# first example; linking tasks that have no DAG is expected to be rejected.
push_task = PythonOperator(python_callable=push_data, task_id='push_task')
pull_task = PythonOperator(python_callable=pull_data, task_id='pull_task')

# Producer first, consumer second (equivalent to push_task >> pull_task).
push_task.set_downstream(pull_task)

动态 DAG

python
# Generate one extract -> transform chain per table from a config list.
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.dummy import DummyOperator  # EmptyOperator on Airflow 2.3+

tables = ['orders', 'users', 'products', 'payments']

# Concrete DAG arguments (a scheduled DAG needs a start_date); values mirror
# the daily_etl_pipeline example.
with DAG(
    'dynamic_etl',
    schedule_interval='@daily',
    start_date=datetime(2024, 1, 1),
    catchup=False,
) as dag:
    start = DummyOperator(task_id='start')
    end = DummyOperator(task_id='end')

    for table in tables:
        # Inside the f-string, {{{{ ds }}}} renders as the Jinja template {{ ds }}.
        extract = BashOperator(
            task_id=f'extract_{table}',
            bash_command=f'python extract.py --table {table} --date {{{{ ds }}}}',
        )
        transform = BashOperator(
            task_id=f'transform_{table}',
            bash_command=f'python transform.py --table {table} --date {{{{ ds }}}}',
        )
        start >> extract >> transform >> end

变量与连接

python
from airflow.models import Variable
from airflow.hooks.base import BaseHook

# Airflow Variables (configured in the UI); `etl_config` is stored as JSON.
# NOTE(review): at module level these lookups run every time the file is parsed;
# consider reading them inside a task callable or via Jinja ({{ var.value.mysql_host }}).
db_host = Variable.get(key="mysql_host")
config = Variable.get(key="etl_config", deserialize_json=True)

# Airflow Connection (configured in the UI).
conn = BaseHook.get_connection(conn_id="mysql_default")
print("Host: {}, Port: {}".format(conn.host, conn.port))

触发规则

python
from airflow.utils.trigger_rule import TriggerRule

def _task_body():
    """Placeholder callable for the trigger-rule examples; replace with real task logic."""


# Default: run only when every upstream task succeeded.
task_default = PythonOperator(
    task_id='default',
    python_callable=_task_body,
    trigger_rule=TriggerRule.ALL_SUCCESS,  # the default, shown explicitly
)

# Run as soon as any upstream task succeeded.
task_any = PythonOperator(
    task_id='any_success',
    python_callable=_task_body,
    trigger_rule=TriggerRule.ONE_SUCCESS,
)

# Run once all upstream tasks finished, whether they succeeded or failed.
task_all_done = PythonOperator(
    task_id='cleanup',
    python_callable=_task_body,
    trigger_rule=TriggerRule.ALL_DONE,
)

# Run as soon as any upstream task failed (e.g. an alerting task).
task_on_fail = PythonOperator(
    task_id='send_alert',
    python_callable=_task_body,
    trigger_rule=TriggerRule.ONE_FAILED,
)

常用命令

bash
# Initialise the metadata database
# (NOTE: on Airflow 2.7+ `airflow db migrate` supersedes the deprecated `db init`)
airflow db init

# Start the web server on port 8080
airflow webserver -p 8080

# Start the scheduler
airflow scheduler

# List DAGs
airflow dags list

# Trigger a DAG run manually for a given logical date
airflow dags trigger daily_etl_pipeline -e 2024-01-01

# Show the state of a DAG run
airflow dags state daily_etl_pipeline 2024-01-01

# Run a single task locally without recording state
airflow tasks test daily_etl_pipeline extract_from_mysql 2024-01-01

# Clear task state so the task runs again (prompts for confirmation)
airflow tasks clear daily_etl_pipeline --task-regex extract_from_mysql --start-date 2024-01-01 --end-date 2024-01-01

# Pause / resume a DAG
airflow dags pause daily_etl_pipeline
airflow dags unpause daily_etl_pipeline

本站内容由 褚成志 整理编写,仅供学习参考