
Spark SQL — DataFrame & Dataset

What Is Spark SQL

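Spark SQL is Spark's module for structured data processing. It provides the DataFrame and Dataset APIs along with a standard SQL interface. Compared with working directly on RDDs, Spark SQL performs better thanks to the Catalyst optimizer and the Tungsten execution engine.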

Key Advantages

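  • SQL interface: write SQL queries directly
  • Automatic optimization: the Catalyst optimizer rewrites execution plans automatically
  • Efficient memory layout: Tungsten manages memory explicitly and stores rows in a compact binary format, which cuts object overhead and GC pressure (cached data additionally uses an in-memory columnar format)
  • Hive integration: read and write Hive tables, compatible with HiveQL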

DataFrame vs Dataset vs RDD

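| Feature | RDD | DataFrame | Dataset |
|---|---|---|---|
| Type safety | ✅ compile time | ❌ runtime | ✅ compile time |
| Optimizer | ❌ none | ✅ Catalyst | ✅ Catalyst |
| Serialization | Java/Kryo | Tungsten | Tungsten |
| API style | Functional | SQL/DSL | Functional + SQL |
| Languages | Scala/Java/Python | Scala/Java/Python/R | Scala/Java |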

Relationship: DataFrame = Dataset[Row] (in Scala, DataFrame is a type alias for Dataset[Row])


Creating a SparkSession

python
from pyspark.sql import SparkSession

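spark = (
    SparkSession.builder
    .appName("SparkSQLDemo")
    .master("local[*]")
    .config("spark.sql.shuffle.partitions", "200")  # shuffle partition count (200 is also the default)
    .config("spark.sql.adaptive.enabled", "true")   # enable Adaptive Query Execution
    .enableHiveSupport()  # enable Hive support (needs Hive classes on the classpath)
    .getOrCreate()
)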

DataFrame Operations

Creating a DataFrame

python
# Create from a Python list with an explicit schema
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType

schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
    StructField("salary", DoubleType(), True),
])

data = [(1, "Alice", 28, 8000.0),
        (2, "Bob", 35, 12000.0),
        (3, "Charlie", 25, 6000.0)]

df = spark.createDataFrame(data, schema)

# Read from CSV (inferSchema costs an extra pass over the data)
df = spark.read \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .csv("hdfs://namenode:9000/data/users.csv")

# Read from Parquet
df = spark.read.parquet("hdfs://namenode:9000/data/users.parquet")

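# Read from JSON (expects JSON Lines by default, one object per line; use the multiLine option for a single JSON document)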
df = spark.read.json("hdfs://namenode:9000/data/users.json")

# Read from a Hive table
df = spark.table("dw_db.dwd_user_log")
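The schema above marks id as non-nullable and declares salary as DoubleType. As a result, createDataFrame rejects rows that break the schema when schema verification is on, which is its default. This is also why the sample salaries are written as 8000.0 rather than 8000. The plain-Python sketch below models that check. It is a simplified illustration, not PySpark's code, and the SCHEMA list and row_error helper are hypothetical.

```python
from typing import Optional

# Hypothetical mirror of the StructType above: (name, Python type, nullable)
SCHEMA = [
    ("id", int, False),
    ("name", str, True),
    ("age", int, True),
    ("salary", float, True),
]


def row_error(row: tuple, schema: list) -> Optional[str]:
    """Return a description of the first schema violation in `row`, or None if it fits."""
    if len(row) != len(schema):
        return f"expected {len(schema)} fields, got {len(row)}"
    for value, (name, py_type, nullable) in zip(row, schema):
        if value is None:
            if not nullable:
                return f"field {name}: not nullable but got None"
            continue
        # bool is a subclass of int in Python; treat it as a mismatch for these fields
        if isinstance(value, bool) or not isinstance(value, py_type):
            return f"field {name}: expected {py_type.__name__}, got {type(value).__name__}"
    return None


for row in [(1, "Alice", 28, 8000.0), (2, "Bob", 35, 12000), (None, "Eve", 30, 5000.0)]:
    print(row, "->", row_error(row, SCHEMA) or "ok")
```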

DataFrame Transformations

python
from pyspark.sql import functions as F
from pyspark.sql.functions import col, lit, when, coalesce

# select: pick columns or expressions
df.select("id", "name", "age")
df.select(col("id"), col("name"), (col("salary") * 1.1).alias("new_salary"))

# filter / where: keep matching rows (Column expression or SQL string)
df.filter(col("age") > 25)
df.where("age > 25 AND salary > 8000")

# withColumn: add or replace a column (assigned back so the groupBy below can use age_group)
df = df.withColumn("age_group",
    when(col("age") < 30, "young")
    .when(col("age") < 50, "middle")
    .otherwise("senior")
)

# groupBy + agg: grouped aggregation
df.groupBy("age_group").agg(
    F.count("id").alias("cnt"),
    F.avg("salary").alias("avg_salary"),
    F.max("salary").alias("max_salary"),
    F.sum("salary").alias("total_salary")
)

# orderBy / sort: sorting
df.orderBy(col("salary").desc(), col("age").asc())

# join: joins (df1 and df2 stand for any two DataFrames)
df1.join(df2, df1.id == df2.user_id, "left")
df1.join(df2, "user_id", "inner")  # join on a column that has the same name on both sides

# distinct / dropDuplicates: remove duplicate rows (dropDuplicates can compare a subset of columns)
df.distinct()
df.dropDuplicates(["name", "age"])

# limit: cap the number of rows
df.limit(100)

# union: append rows, matching columns by position (same column count, compatible types)
df1.union(df2)
df1.unionByName(df2)  # match columns by name (column order may differ)

# drop: remove columns
df.drop("salary", "age")

# withColumnRenamed: rename a column
df.withColumnRenamed("name", "user_name")

# fillna: replace nulls with per-column values
df.fillna({"age": 0, "salary": 0.0})

# dropna: drop rows that have nulls in the given columns
df.dropna(subset=["name", "age"])
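union and unionByName differ only in how they line up columns. If the two sides list the same columns in a different order, union misaligns the data, and it may not raise an error. The minimal plain-Python model below shows the difference. Rows are tuples, and the helper names are hypothetical.

```python
def union_by_position(left_cols, left_rows, right_cols, right_rows):
    """Model of DataFrame.union: columns are matched by position, names are ignored."""
    if len(left_cols) != len(right_cols):
        raise ValueError("union requires the same number of columns")
    # The result keeps the left side's column names
    return left_cols, left_rows + right_rows


def union_by_name(left_cols, left_rows, right_cols, right_rows):
    """Model of DataFrame.unionByName: right-hand columns are reordered to match the left."""
    if sorted(left_cols) != sorted(right_cols):
        raise ValueError("unionByName requires the same column names")
    positions = [right_cols.index(c) for c in left_cols]
    reordered = [tuple(row[i] for i in positions) for row in right_rows]
    return left_cols, left_rows + reordered


left_cols, left_rows = ["id", "name"], [(1, "Alice")]
right_cols, right_rows = ["name", "id"], [("Bob", 2)]
print(union_by_position(left_cols, left_rows, right_cols, right_rows))  # "Bob" lands under "id"
print(union_by_name(left_cols, left_rows, right_cols, right_rows))
```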

Window Functions

python
from pyspark.sql.window import Window

# Define a window: partition by department, order by salary descending
window_spec = Window.partitionBy("department") \
                    .orderBy(col("salary").desc())

# Ranking functions
df.withColumn("rank", F.rank().over(window_spec)) \
  .withColumn("dense_rank", F.dense_rank().over(window_spec)) \
  .withColumn("row_number", F.row_number().over(window_spec))

# Aggregate window: no orderBy, so the frame is the whole partition
window_agg = Window.partitionBy("department")
df.withColumn("dept_avg_salary", F.avg("salary").over(window_agg)) \
  .withColumn("dept_total", F.sum("salary").over(window_agg))

# Sliding window: the 3 preceding rows through the current row (up to 4 rows)
window_slide = Window.partitionBy("user_id") \
                     .orderBy("event_time") \
                     .rowsBetween(-3, 0)
df.withColumn("rolling_avg", F.avg("value").over(window_slide))

# LAG / LEAD
window_lag = Window.partitionBy("user_id").orderBy("event_time")
df.withColumn("prev_value", F.lag("value", 1).over(window_lag)) \
  .withColumn("next_value", F.lead("value", 1).over(window_lag))
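The three ranking functions give different results only when the ordering column has ties. The rowsBetween(-3, 0) frame covers up to four rows. The plain-Python sketch below reproduces these semantics for a single, already-sorted partition. It is a model for intuition only: it ignores nulls and partitioning, and all function names are hypothetical.

```python
def ranks(sorted_values):
    """row_number, rank and dense_rank for one partition already sorted in window order."""
    row_number, rank, dense_rank = [], [], []
    previous = object()  # sentinel that never equals a real value
    current_rank = current_dense = 0
    for position, value in enumerate(sorted_values, start=1):
        if value != previous:
            current_rank = position   # rank jumps to the row position after a tie
            current_dense += 1        # dense_rank just increments, leaving no gaps
            previous = value
        row_number.append(position)
        rank.append(current_rank)
        dense_rank.append(current_dense)
    return row_number, rank, dense_rank


def rolling_avg(values, preceding=3):
    """avg over rowsBetween(-preceding, 0): the current row plus up to `preceding` earlier rows."""
    result = []
    for i in range(len(values)):
        frame = values[max(0, i - preceding): i + 1]
        result.append(sum(frame) / len(frame))
    return result


def lag(values, offset=1, default=None):
    return [values[i - offset] if i - offset >= 0 else default for i in range(len(values))]


def lead(values, offset=1, default=None):
    return [values[i + offset] if i + offset < len(values) else default for i in range(len(values))]


salaries_desc = [9000, 8000, 8000, 7000]
print(ranks(salaries_desc))
print(rolling_avg([1, 2, 3, 4, 5]))
print(lag([10, 20, 30]), lead([10, 20, 30]))
```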

SQL Interface

python
# Register a temporary view (scoped to this SparkSession)
df.createOrReplaceTempView("users")

# Run SQL
result = spark.sql("""
    SELECT
        age_group,
        COUNT(*) AS cnt,
        AVG(salary) AS avg_salary,
        PERCENTILE_APPROX(salary, 0.5) AS median_salary
    FROM (
        SELECT *,
            CASE
                WHEN age < 30 THEN 'young'
                WHEN age < 50 THEN 'middle'
                ELSE 'senior'
            END AS age_group
        FROM users
    ) t
    GROUP BY age_group
    ORDER BY avg_salary DESC
""")

# Register a global temporary view (shared across SparkSessions of the same application, under the global_temp database)
df.createOrReplaceGlobalTempView("global_users")
spark.sql("SELECT * FROM global_temp.global_users")

Reading and Writing Data Sources

python
# Write Parquet (recommended format), one output directory per dt value
df.write \
  .mode("overwrite") \
  .partitionBy("dt") \
  .parquet("hdfs://namenode:9000/output/users/")

# Write ORC (append to existing data)
df.write.mode("append").orc("/output/users_orc/")

# Write CSV with a header row
df.write \
  .mode("overwrite") \
  .option("header", "true") \
  .csv("/output/users_csv/")

# Write into an existing Hive table (insertInto matches columns by position, not by name)
df.write \
  .mode("overwrite") \
  .insertInto("dw_db.dws_user_daily")

# Write via JDBC (MySQL; the MySQL JDBC driver must be on the classpath)
df.write \
  .format("jdbc") \
  .option("url", "jdbc:mysql://localhost:3306/db") \
  .option("dbtable", "users") \
  .option("user", "root") \
  .option("password", "password") \
  .mode("append") \
  .save()

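# Read via JDBC as 10 parallel range queries on id; lowerBound/upperBound only set the stride, they do not filter rows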
df = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:mysql://localhost:3306/db") \
    .option("dbtable", "(SELECT * FROM users WHERE age > 25) t") \
    .option("user", "root") \
    .option("password", "password") \
    .option("numPartitions", "10") \
    .option("partitionColumn", "id") \
    .option("lowerBound", "1") \
    .option("upperBound", "1000000") \
    .load()
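When partitionColumn, lowerBound, upperBound and numPartitions are set, Spark issues one query per partition. Each query carries a range predicate on the partition column. The bounds only determine the stride. The first range is open below and also picks up NULLs, and the last range is open above, so no rows are filtered out. The sketch below approximates how those predicates are derived for the options used above. It is a simplified model written for illustration. It assumes non-negative integer bounds and at least two partitions, and the function name is hypothetical.

```python
def jdbc_partition_predicates(column, lower, upper, num_partitions):
    """Approximate the per-partition WHERE clauses of a partitioned JDBC read."""
    if num_partitions < 2:
        raise ValueError("this sketch models the multi-partition case only")
    # Assumes non-negative bounds (Python's // floors, which differs for negatives)
    stride = upper // num_partitions - lower // num_partitions
    predicates = []
    current = lower
    for i in range(num_partitions):
        lower_clause = f"{column} >= {current}" if i != 0 else None
        current += stride
        upper_clause = f"{column} < {current}" if i != num_partitions - 1 else None
        if lower_clause is None:
            # First partition: everything below the first boundary, plus NULLs
            predicates.append(f"{upper_clause} OR {column} IS NULL")
        elif upper_clause is None:
            # Last partition: everything from the last boundary upwards
            predicates.append(lower_clause)
        else:
            predicates.append(f"{lower_clause} AND {upper_clause}")
    return predicates


for predicate in jdbc_partition_predicates("id", 1, 1000000, 10):
    print(predicate)
```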

Catalyst Optimizer

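Catalyst is Spark SQL's query optimizer. Every query, whether written in SQL or with the DataFrame API, goes through the following phases: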

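SQL / DataFrame API
      │
      ▼
Unresolved Logical Plan
      │ Analysis: resolve column names and types against the catalog
      ▼
Logical Plan (resolved)
      │ Optimization: predicate pushdown, column pruning, constant folding
      ▼
Optimized Logical Plan
      │ Physical Planning: choose the join strategy and other physical operators
      ▼
Physical Plan
      │ Code Generation: Tungsten whole-stage code generation
      ▼
Execution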
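In the Analysis phase, a query that only names its columns becomes one that knows which attribute and type each name refers to. An unknown name fails at this point, before anything runs. The toy resolver below, in plain Python, illustrates the idea. It is not Catalyst's implementation: the schema dictionary and the resolve_columns helper are hypothetical. Lookup is case-insensitive because spark.sql.caseSensitive defaults to false.

```python
class AnalysisError(Exception):
    """Stand-in for the error raised when a column reference cannot be resolved."""


def resolve_columns(referenced, schema):
    """Resolve column names against a {name: type} schema, case-insensitively."""
    by_lower_name = {name.lower(): (name, dtype) for name, dtype in schema.items()}
    resolved = []
    for name in referenced:
        match = by_lower_name.get(name.lower())
        if match is None:
            raise AnalysisError(f"cannot resolve column {name!r}; available: {sorted(schema)}")
        resolved.append(match)  # canonical name and its type
    return resolved


users_schema = {"id": "int", "name": "string", "age": "int", "salary": "double"}
print(resolve_columns(["NAME", "salary"], users_schema))
```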

Main Optimization Rules

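  • Predicate pushdown: push filters down to the data source so less data is read
  • Column pruning: read only the columns the query needs
  • Constant folding: evaluate constant expressions once during optimization instead of once per row (for example, 1 + 2 becomes 3)
  • Join reordering and build-side selection: reorder multi-way joins (cost-based, when CBO is enabled) and put the smaller table on the build side of a hash join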
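Catalyst implements these optimizations as rules that rewrite expression and plan trees. The toy sketch below is plain Python with hypothetical class and function names, and it is far simpler than Catalyst's Scala rules. It shows constant folding as a bottom-up rewrite. It also shows the references computation that column pruning depends on: only the columns an expression references need to be read.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Lit:
    value: int


@dataclass(frozen=True)
class Col:
    name: str


@dataclass(frozen=True)
class Add:
    left: "Expr"
    right: "Expr"


@dataclass(frozen=True)
class Mul:
    left: "Expr"
    right: "Expr"


Expr = Union[Lit, Col, Add, Mul]


def fold_constants(expr: Expr) -> Expr:
    """Constant folding: rewrite children first, then collapse literal-only arithmetic."""
    if isinstance(expr, (Add, Mul)):
        left, right = fold_constants(expr.left), fold_constants(expr.right)
        if isinstance(left, Lit) and isinstance(right, Lit):
            if isinstance(expr, Add):
                return Lit(left.value + right.value)
            return Lit(left.value * right.value)
        return type(expr)(left, right)
    return expr


def references(expr: Expr) -> set:
    """Columns an expression reads; column pruning keeps only these in the scan."""
    if isinstance(expr, Col):
        return {expr.name}
    if isinstance(expr, (Add, Mul)):
        return references(expr.left) | references(expr.right)
    return set()


# salary + 2 * 3  ->  salary + 6, and only `salary` has to be read
expr = Add(Col("salary"), Mul(Lit(2), Lit(3)))
print(fold_constants(expr), references(expr))
```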

Adaptive Query Execution (AQE)

Spark 3.0 introduced AQE, which re-optimizes the execution plan at runtime using statistics collected from completed shuffle stages:

python
# Enable AQE (on by default since Spark 3.2)
spark.conf.set("spark.sql.adaptive.enabled", "true")

# Coalesce small post-shuffle partitions into partitions of roughly the advisory size
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")

# Automatically split skewed partitions in sort-merge joins
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")

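# Runtime join switching: when one side of a sort-merge join turns out to be small, AQE converts it to a broadcast hash join; the local shuffle reader then reads shuffle files locally to save network traffic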
spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", "true")
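The two AQE features configured above can be modeled in a few lines. The sketch below is a simplified approximation, not Spark's algorithm. First, it greedily merges adjacent post-shuffle partitions up to the advisory size. Second, it flags a partition as skewed when the partition is larger than both factor times the median size and the absolute threshold. The factor mirrors spark.sql.adaptive.skewJoin.skewedPartitionFactor (default 5), and the threshold mirrors the skewedPartitionThresholdInBytes value set above. Sizes are plain integers in MB, and the function names are hypothetical.

```python
import statistics


def coalesce_partitions(sizes_mb, target_mb=128):
    """Greedily group adjacent shuffle partitions so each group stays near target_mb."""
    groups, current, current_size = [], [], 0
    for index, size in enumerate(sizes_mb):
        # Start a new group if adding this partition would overshoot the target
        if current and current_size + size > target_mb:
            groups.append(current)
            current, current_size = [], 0
        current.append(index)
        current_size += size
    if current:
        groups.append(current)
    return groups


def skewed_partitions(sizes_mb, factor=5, threshold_mb=256):
    """Indexes of partitions larger than factor * median AND larger than threshold_mb."""
    median = statistics.median(sizes_mb)
    return [i for i, size in enumerate(sizes_mb) if size > factor * median and size > threshold_mb]


print(coalesce_partitions([10, 20, 30, 100, 5, 5], target_mb=64))
print(skewed_partitions([100] * 9 + [2000]))
```

The threshold check stops small jobs from being flagged just because one tiny partition is several times the median.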

Performance Tuning

Broadcast Join

python
from pyspark.sql.functions import broadcast

# Broadcast the small table to every executor (the large table is not shuffled)
result = large_df.join(broadcast(small_df), "user_id")

# Automatic broadcast threshold (default 10MB; set to -1 to disable)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50MB")
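A broadcast hash join works because every task receives the whole small table. Each executor builds a hash table from it and probes that table with rows from its own partition of the large table, so the large table never moves. The plain-Python model below illustrates that shape. It handles inner joins only, rows are dicts, and the names are hypothetical. It does not describe how Spark ships or stores the broadcast data.

```python
from collections import defaultdict


def broadcast_hash_join(large_partitions, small_rows, key):
    """Inner-join each partition of the large table against one shared hash table."""
    # Build side: hash the small table once (this is the part that gets broadcast)
    hash_table = defaultdict(list)
    for row in small_rows:
        hash_table[row[key]].append(row)

    # Probe side: each partition is joined locally, no shuffle of the large table
    joined_partitions = []
    for partition in large_partitions:
        joined = []
        for row in partition:
            for match in hash_table.get(row[key], []):
                joined.append({**row, **match})
        joined_partitions.append(joined)
    return joined_partitions


large = [[{"user_id": 1, "amount": 10}, {"user_id": 3, "amount": 5}],
         [{"user_id": 2, "amount": 7}]]
small = [{"user_id": 1, "city": "A"}, {"user_id": 2, "city": "B"}]
print(broadcast_hash_join(large, small, "user_id"))
```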

Partition Optimization

python
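# Repartition after reading to control the downstream task count (repartition always does a full shuffle; coalesce() only reduces partitions and avoids it)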
df = spark.read.parquet("/data/").repartition(200)

# Bucket by a column on write so later joins on user_id can skip the shuffle (bucketBy works only with saveAsTable)
df.write \
  .bucketBy(32, "user_id") \
  .sortBy("user_id") \
  .saveAsTable("bucketed_users")
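Bucketing pays off because both sides of a later join assign keys to buckets with the same hash function and bucket count. As a result, bucket i of one table only ever needs to meet bucket i of the other. The plain-Python sketch below models that idea. Spark computes bucket ids with its own Murmur3-based hash; zlib.crc32 is used here only as a stand-in, and the helper names are hypothetical.

```python
import zlib


def bucket_id(key, num_buckets):
    # Stand-in hash; Spark itself uses a Murmur3-based hash for bucket ids
    return zlib.crc32(str(key).encode("utf-8")) % num_buckets


def bucketize(rows, key, num_buckets):
    """Write-side model of bucketBy(num_buckets, key).sortBy(key)."""
    buckets = [[] for _ in range(num_buckets)]
    for row in rows:
        buckets[bucket_id(row[key], num_buckets)].append(row)
    for bucket in buckets:
        bucket.sort(key=lambda r: r[key])
    return buckets


def bucketed_join(left_buckets, right_buckets, key):
    """Inner join that only pairs bucket i with bucket i, so no rows are reshuffled."""
    if len(left_buckets) != len(right_buckets):
        raise ValueError("both tables must use the same number of buckets")
    result = []
    for left, right in zip(left_buckets, right_buckets):
        index = {}
        for row in right:
            index.setdefault(row[key], []).append(row)
        for row in left:
            for match in index.get(row[key], []):
                result.append({**row, **match})
    return result


users = [{"user_id": i, "name": f"u{i}"} for i in range(20)]
orders = [{"user_id": i % 25, "amount": i} for i in range(50)]
joined = bucketed_join(bucketize(users, "user_id", 32), bucketize(orders, "user_id", 32), "user_id")
print(len(joined))
```

If the two tables used different bucket counts, matching keys could land in different bucket indexes, and Spark would have to shuffle again.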

Site content compiled and written by 褚成志, for learning reference only