Spark SQL — DataFrame & Dataset
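What Is Spark SQL
Spark SQL is Spark's module for structured data processing. It provides the DataFrame and Dataset APIs as well as a standard SQL interface. Compared with RDDs, Spark SQL delivers better performance through the Catalyst optimizer and the Tungsten execution engine.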
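Core advantages:
- SQL interface: write SQL queries directly
- Automatic optimization: the Catalyst optimizer rewrites and optimizes the execution plan automatically
- Efficient memory layout: Tungsten stores rows in a compact binary format with explicit (optionally off-heap) memory management, which reduces GC pressure; cached DataFrames are kept in a columnar format
- Hive integration: read and write Hive tables, compatible with HiveQL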
DataFrame vs Dataset vs RDD
| Feature | RDD | DataFrame | Dataset |
|---|---|---|---|
| Type safety | ✅ Compile time | ❌ Runtime | ✅ Compile time |
| Optimizer | ❌ | ✅ Catalyst | ✅ Catalyst |
| Serialization | Java/Kryo | Tungsten | Tungsten (Encoders) |
| API style | Functional | SQL/DSL | Functional + SQL |
| Languages | Scala/Java/Python | Scala/Java/Python/R | Scala/Java |
Relationship: DataFrame = Dataset[Row] (in Scala, DataFrame is a type alias for Dataset[Row]; in Java it is Dataset<Row>)
Creating a SparkSession
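```python
from pyspark.sql import SparkSession

# Parentheses let each builder call carry its own inline comment
spark = (
    SparkSession.builder
    .appName("SparkSQLDemo")
    .master("local[*]")
    .config("spark.sql.shuffle.partitions", "200")  # 200 is also the default
    .config("spark.sql.adaptive.enabled", "true")
    .enableHiveSupport()  # enable Hive support (needs Hive classes on the classpath)
    .getOrCreate()
)
```
DataFrame Operations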
Creating a DataFrame
```python
from pyspark.sql.types import (
    StructType, StructField, IntegerType, StringType, DoubleType,
)

# Create from a Python list with an explicit schema
schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
    StructField("salary", DoubleType(), True),
])
data = [(1, "Alice", 28, 8000.0),
        (2, "Bob", 35, 12000.0),
        (3, "Charlie", 25, 6000.0)]
df = spark.createDataFrame(data, schema)

# Read from CSV (inferSchema costs an extra pass over the data)
df = spark.read \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .csv("hdfs://namenode:9000/data/users.csv")

# Read from Parquet
df = spark.read.parquet("hdfs://namenode:9000/data/users.parquet")

# Read from JSON (JSON Lines by default: one JSON object per line)
df = spark.read.json("hdfs://namenode:9000/data/users.json")

# Read from a Hive table
df = spark.table("dw_db.dwd_user_log")
```
DataFrame Transformations
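The difference between union and unionByName in the transformations below is easy to miss. The following plain-Python sketch is not PySpark; it models a table as a (columns, rows) pair. The helpers union_by_position and union_by_name are hypothetical. The sketch shows that a positional union silently mixes values from different columns when the column order differs.

```python
def union_by_position(left, right):
    """Like DataFrame.union: rows are appended by column position; names are ignored."""
    left_cols, left_rows = left
    right_cols, right_rows = right
    if len(left_cols) != len(right_cols):
        raise ValueError("union needs the same number of columns")
    return left_cols, left_rows + right_rows


def union_by_name(left, right):
    """Like DataFrame.unionByName: right-hand rows are reordered to the left column order."""
    left_cols, left_rows = left
    right_cols, right_rows = right
    if set(left_cols) != set(right_cols):
        raise ValueError("unionByName needs the same column names")
    order = [right_cols.index(c) for c in left_cols]
    reordered = [tuple(row[i] for i in order) for row in right_rows]
    return left_cols, left_rows + reordered


a = (["id", "name"], [(1, "Alice")])
b = (["name", "id"], [("Bob", 2)])
print(union_by_position(a, b))  # (['id', 'name'], [(1, 'Alice'), ('Bob', 2)])  columns mixed up
print(union_by_name(a, b))      # (['id', 'name'], [(1, 'Alice'), (2, 'Bob')])
```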
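```python
from pyspark.sql import functions as F
from pyspark.sql.functions import col, when

# DataFrames are immutable: each transformation returns a new DataFrame,
# so assign the result whenever a later step needs it.
# df1 and df2 below are two example DataFrames.

# select: choose columns
df.select("id", "name", "age")
df.select(col("id"), col("name"), (col("salary") * 1.1).alias("new_salary"))

# filter / where: filter rows (where is an alias of filter)
df.filter(col("age") > 25)
df.where("age > 25 AND salary > 8000")

# withColumn: add or replace a column
df_with_group = df.withColumn(
    "age_group",
    when(col("age") < 30, "young")
    .when(col("age") < 50, "middle")
    .otherwise("senior"),
)

# groupBy + agg: grouped aggregation on the age_group column added above
df_with_group.groupBy("age_group").agg(
    F.count("id").alias("cnt"),
    F.avg("salary").alias("avg_salary"),
    F.max("salary").alias("max_salary"),
    F.sum("salary").alias("total_salary"),
)

# orderBy / sort: sorting
df.orderBy(col("salary").desc(), col("age").asc())

# join
df1.join(df2, df1.id == df2.user_id, "left")
df1.join(df2, "user_id", "inner")  # join on a column with the same name in both

# distinct / dropDuplicates: remove duplicates
df.distinct()
df.dropDuplicates(["name", "age"])  # deduplicate on a subset of columns

# limit: cap the number of rows
df.limit(100)

# union: matches columns by position (same column count, compatible types);
# column names are not checked
df1.union(df2)
df1.unionByName(df2)  # matches columns by name (column order may differ)

# drop: remove columns
df.drop("salary", "age")

# withColumnRenamed: rename a column
df.withColumnRenamed("name", "user_name")

# fillna: replace nulls per column
df.fillna({"age": 0, "salary": 0.0})

# dropna: drop rows that have nulls in the given columns
df.dropna(subset=["name", "age"])
```
Window Functions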
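Before the PySpark examples, a plain-Python sketch can clarify three behaviours: how the ranking functions treat ties, what a rowsBetween(-3, 0) frame averages over, and what lag/lead return within one partition. It only mimics the semantics on an already sorted list of values for a single partition. The helpers ranking, rolling_avg, lag and lead are hypothetical and are not PySpark functions.

```python
def ranking(sorted_values):
    """(rank, dense_rank, row_number) for each value of one partition, already sorted."""
    results = []
    rank = dense = 0
    prev = object()  # sentinel that never equals a real value
    for position, value in enumerate(sorted_values, start=1):
        if value != prev:
            rank = position  # rank jumps past tied rows, leaving gaps
            dense += 1       # dense_rank always advances by one
            prev = value
        results.append((rank, dense, position))  # row_number is just the position
    return results


def rolling_avg(values, preceding=3):
    """Average over rowsBetween(-preceding, 0): up to `preceding` earlier rows plus the current row."""
    out = []
    for i in range(len(values)):
        frame = values[max(0, i - preceding): i + 1]
        out.append(sum(frame) / len(frame))
    return out


def lag(values, offset=1, default=None):
    """Value `offset` rows earlier in the partition, or default at the start."""
    return [values[i - offset] if i - offset >= 0 else default for i in range(len(values))]


def lead(values, offset=1, default=None):
    """Value `offset` rows later in the partition, or default at the end."""
    return [values[i + offset] if i + offset < len(values) else default for i in range(len(values))]


print(ranking([9000, 8000, 8000, 7000]))  # [(1, 1, 1), (2, 2, 2), (2, 2, 3), (4, 3, 4)]
print(rolling_avg([1, 2, 3, 4, 5]))       # [1.0, 1.5, 2.0, 2.5, 3.5]
```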
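```python
from pyspark.sql import functions as F
from pyspark.sql.functions import col
from pyspark.sql.window import Window

# Assumes df has department/salary columns, or user_id/event_time/value columns

# Define a window: one partition per department, highest salary first
window_spec = Window.partitionBy("department") \
    .orderBy(col("salary").desc())

# Ranking functions (on ties: rank leaves gaps, dense_rank does not,
# row_number is always unique)
df.withColumn("rank", F.rank().over(window_spec)) \
    .withColumn("dense_rank", F.dense_rank().over(window_spec)) \
    .withColumn("row_number", F.row_number().over(window_spec))

# Aggregate window: without orderBy, the frame is the whole partition
window_agg = Window.partitionBy("department")
df.withColumn("dept_avg_salary", F.avg("salary").over(window_agg)) \
    .withColumn("dept_total", F.sum("salary").over(window_agg))

# Sliding window: the 3 preceding rows plus the current row (up to 4 rows)
window_slide = Window.partitionBy("user_id") \
    .orderBy("event_time") \
    .rowsBetween(-3, Window.currentRow)
df.withColumn("rolling_avg", F.avg("value").over(window_slide))

# LAG / LEAD: value from the previous / next row in the partition
window_lag = Window.partitionBy("user_id").orderBy("event_time")
df.withColumn("prev_value", F.lag("value", 1).over(window_lag)) \
    .withColumn("next_value", F.lead("value", 1).over(window_lag))
```
SQL Interface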
```python
# Register a temporary view (visible only in the current SparkSession)
df.createOrReplaceTempView("users")

# Run a SQL query
result = spark.sql("""
    SELECT
        age_group,
        COUNT(*) AS cnt,
        AVG(salary) AS avg_salary,
        PERCENTILE_APPROX(salary, 0.5) AS median_salary  -- approximate median
    FROM (
        SELECT *,
            CASE
                WHEN age < 30 THEN 'young'
                WHEN age < 50 THEN 'middle'
                ELSE 'senior'
            END AS age_group
        FROM users
    ) t
    GROUP BY age_group
    ORDER BY avg_salary DESC
""")

# Register a global temporary view: visible to every SparkSession in the same
# application and accessed through the global_temp database
df.createOrReplaceGlobalTempView("global_users")
spark.sql("SELECT * FROM global_temp.global_users")
```
Reading and Writing Data Sources
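For the partitioned JDBC read at the end of this section, it helps to know how Spark turns partitionColumn, lowerBound, upperBound and numPartitions into one WHERE clause per partition. The plain-Python sketch follows only the basic stride logic and assumes non-negative integer bounds. It ignores refinements in newer Spark versions, such as stride alignment and reducing numPartitions for narrow ranges, so exact boundaries can differ slightly. jdbc_where_clauses is a hypothetical helper. The behaviour it shows is the important part: the bounds only decide the stride and do not filter rows. The first partition also picks up NULLs and everything below the lower bound, and the last partition picks up everything above the upper bound.

```python
def jdbc_where_clauses(column, lower, upper, num_partitions):
    """Simplified stride-based split of a partitioned JDBC read (non-negative integer bounds).

    Returns one WHERE clause per partition; None means that partition has no filter.
    """
    stride = upper // num_partitions - lower // num_partitions
    clauses = []
    current = lower
    for i in range(num_partitions):
        lo = f"{column} >= {current}" if i != 0 else None
        current += stride
        hi = f"{column} < {current}" if i != num_partitions - 1 else None
        if hi is None:
            clauses.append(lo)  # last partition: no upper limit
        elif lo is None:
            clauses.append(f"{hi} or {column} is null")  # first partition: no lower limit, plus NULLs
        else:
            clauses.append(f"{lo} AND {hi}")
    return clauses


clauses = jdbc_where_clauses("id", 1, 1000000, 10)
print(clauses[0])   # id < 100001 or id is null
print(clauses[1])   # id >= 100001 AND id < 200001
print(clauses[-1])  # id >= 900001
```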
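```python
# Write Parquet (the recommended format); partitionBy creates one dt=... directory per value.
# "overwrite" replaces the whole output path unless
# spark.sql.sources.partitionOverwriteMode is set to "dynamic"
df.write \
    .mode("overwrite") \
    .partitionBy("dt") \
    .parquet("hdfs://namenode:9000/output/users/")

# Write ORC
df.write.mode("append").orc("/output/users_orc/")

# Write CSV
df.write \
    .mode("overwrite") \
    .option("header", "true") \
    .csv("/output/users_csv/")

# Write to an existing Hive table; insertInto matches columns by position, not by name
df.write \
    .mode("overwrite") \
    .insertInto("dw_db.dws_user_daily")

# Write over JDBC (MySQL); the MySQL JDBC driver must be on the classpath,
# and real jobs should not hard-code credentials
df.write \
    .format("jdbc") \
    .option("url", "jdbc:mysql://localhost:3306/db") \
    .option("dbtable", "users") \
    .option("user", "root") \
    .option("password", "password") \
    .mode("append") \
    .save()

# Read over JDBC: a subquery in dbtable needs an alias; partitionColumn must be
# numeric, date, or timestamp; lowerBound/upperBound only set the partition
# stride and do not filter rows
df = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:mysql://localhost:3306/db") \
    .option("dbtable", "(SELECT * FROM users WHERE age > 25) t") \
    .option("user", "root") \
    .option("password", "password") \
    .option("numPartitions", "10") \
    .option("partitionColumn", "id") \
    .option("lowerBound", "1") \
    .option("upperBound", "1000000") \
    .load()
```
Catalyst Optimizer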
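Catalyst is the query optimizer of Spark SQL. Every query passes through the following phases:
```
SQL / DataFrame API
        │
        ▼
Unresolved Logical Plan
        │  Analysis: resolve column names and types against the catalog
        ▼
Logical Plan
        │  Optimization: predicate pushdown, column pruning, constant folding
        ▼
Optimized Logical Plan
        │  Physical Planning: choose join strategies and other physical operators
        ▼
Physical Plan
        │  Code Generation: whole-stage code generation (Tungsten)
        ▼
Execution
```
Main optimization rules: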
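- Predicate pushdown: push filters down to the data source to reduce the amount of data read
- Column pruning: read only the columns the query needs
- Constant folding: evaluate constant expressions once during optimization instead of once per row
- Join reordering: with cost-based optimization enabled (spark.sql.cbo.enabled and spark.sql.cbo.joinReorder.enabled, both off by default) and table statistics collected, Catalyst reorders multi-way joins; separately, the physical planner uses size estimates to put the smaller table on the build side of a hash join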
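Catalyst represents expressions and plans as trees and applies rules that pattern-match a subtree and rewrite it. The toy sketch below applies that idea to two of the rules listed above, using plain tuples as tree nodes. It is not Catalyst's API: real rules are Scala code over Spark's own tree classes. fold_constants and push_down_filter are hypothetical helpers. In real Spark the Filter node usually stays in the plan even after pushdown, because many sources apply pushed filters only on a best-effort basis.

```python
import operator

# Toy expression nodes: ("lit", value), ("col", name), or (op, left, right)
OPS = {"+": operator.add, "*": operator.mul, ">": operator.gt}


def fold_constants(expr):
    """Rewrite bottom-up: an operator whose children are both literals becomes one literal."""
    if expr[0] in ("lit", "col"):
        return expr
    op, left, right = expr
    left, right = fold_constants(left), fold_constants(right)
    if left[0] == "lit" and right[0] == "lit":
        return ("lit", OPS[op](left[1], right[1]))
    return (op, left, right)


# Toy plan nodes: ("Filter", condition, child) and ("Scan", table, pushed_filters)
def push_down_filter(plan):
    """Move a Filter sitting directly on a Scan into the Scan's pushed filters."""
    if plan[0] == "Filter" and plan[2][0] == "Scan":
        _, condition, (_, table, pushed) = plan
        return ("Scan", table, pushed + [condition])
    return plan


predicate = (">", ("col", "salary"), ("*", ("lit", 1000), ("lit", 8)))
print(fold_constants(predicate))  # ('>', ('col', 'salary'), ('lit', 8000))
plan = ("Filter", "age > 25", ("Scan", "users", []))
print(push_down_filter(plan))     # ('Scan', 'users', ['age > 25'])
```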
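Adaptive Query Execution (AQE)
Spark 3.0 introduced AQE, which re-optimizes the execution plan at runtime using statistics collected from completed shuffle stages.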
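Two of the AQE rewrites controlled by the configuration below can be sketched in plain Python. The first merges adjacent small shuffle partitions up to a target size. The second flags a partition as skewed when it is larger than both skewedPartitionFactor (default 5) times the median partition size and skewedPartitionThresholdInBytes. This is a simplified model, not Spark's implementation. Real AQE also enforces a minimum partition size, considers both sides of a join, and then splits each skewed partition into smaller tasks. coalesce_partitions and skewed_partitions are hypothetical helpers.

```python
from statistics import median

MB = 1024 * 1024


def coalesce_partitions(sizes, target):
    """Group adjacent shuffle partitions; start a new group when the next one would exceed target."""
    groups, current, current_size = [], [], 0
    for idx, size in enumerate(sizes):
        if current and current_size + size > target:
            groups.append(current)
            current, current_size = [], 0
        current.append(idx)
        current_size += size
    if current:
        groups.append(current)
    return groups


def skewed_partitions(sizes, factor=5, threshold=256 * MB):
    """Indices of partitions larger than both factor * median size and threshold."""
    med = median(sizes)
    return [i for i, size in enumerate(sizes) if size > factor * med and size > threshold]


# Six post-shuffle partitions (sizes in MB) coalesced toward a 64MB target
print(coalesce_partitions([10, 20, 30, 100, 5, 5], 64))  # [[0, 1, 2], [3], [4, 5]]

sizes = [50 * MB, 60 * MB, 55 * MB, 2000 * MB, 58 * MB]
print(skewed_partitions(sizes))  # [3]
```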
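```python
# Enable AQE (on by default since Spark 3.2)
spark.conf.set("spark.sql.adaptive.enabled", "true")

# Coalesce small post-shuffle partitions (the default advisory size is 64MB)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")

# Handle skewed joins automatically: a partition counts as skewed when it exceeds both
# skewedPartitionFactor (default 5) times the median partition size and this threshold
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")

# Switch join strategies at runtime: a sort-merge join becomes a broadcast hash join when
# one side turns out to be small (spark.sql.adaptive.autoBroadcastJoinThreshold, Spark 3.2+);
# the local shuffle reader then reads shuffle files locally to cut network traffic
spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", "true")
```
Performance Tuning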
Broadcast Join
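A broadcast hash join avoids shuffling the large table because every executor receives a full copy of the small table and builds a hash table from it. The plain-Python sketch below models this in a single process, as an inner join on one key. broadcast_hash_join is a hypothetical helper, not a PySpark function.

```python
from collections import defaultdict


def broadcast_hash_join(large_rows, small_rows, key):
    """Inner join: build a hash table from the small side, then stream the large side past it."""
    # Build phase: in Spark, this hash table is what gets broadcast to every executor
    table = defaultdict(list)
    for row in small_rows:
        table[row[key]].append(row)
    # Probe phase: each large-side row is matched locally, so the large side never moves
    result = []
    for row in large_rows:
        for match in table.get(row[key], []):
            result.append({**row, **match})
    return result


orders = [{"user_id": 1, "amount": 10}, {"user_id": 2, "amount": 20}, {"user_id": 3, "amount": 5}]
users = [{"user_id": 1, "name": "Alice"}, {"user_id": 2, "name": "Bob"}]
print(broadcast_hash_join(orders, users, "user_id"))
# [{'user_id': 1, 'amount': 10, 'name': 'Alice'}, {'user_id': 2, 'amount': 20, 'name': 'Bob'}]
```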
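```python
from pyspark.sql.functions import broadcast

# Broadcast the small table explicitly (avoids shuffling the large table)
result = large_df.join(broadcast(small_df), "user_id")

# Automatic broadcast threshold (default 10MB; set to -1 to disable)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50MB")
```
Partition Optimization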
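The difference between repartition and coalesce comes down to whether rows move. In the plain-Python sketch below, partitions are lists of rows. coalesce only groups neighbouring partitions together, so no row changes hands between groups; this is why Spark can avoid a full shuffle. repartition deals every row out round-robin, as a full shuffle would. Both helpers are hypothetical simplifications. Spark's coalesce also considers data locality, and its round-robin repartitioning starts each input partition at a pseudo-random offset.

```python
def coalesce(partitions, n):
    """Merge neighbouring partitions into min(n, len(partitions)) groups without moving rows across groups."""
    k = len(partitions)
    n = min(n, k)  # without a shuffle, coalesce cannot increase the partition count
    return [
        [row for part in partitions[i * k // n:(i + 1) * k // n] for row in part]
        for i in range(n)
    ]


def repartition(partitions, n):
    """Deal every row out round-robin into n new partitions, as a full shuffle would."""
    out = [[] for _ in range(n)]
    position = 0
    for part in partitions:
        for row in part:
            out[position % n].append(row)
            position += 1
    return out


parts = [[1, 2, 3], [4], [5, 6]]
print(coalesce(parts, 2))     # [[1, 2, 3], [4, 5, 6]]
print(repartition(parts, 2))  # [[1, 3, 5], [2, 4, 6]]
```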
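```python
# Repartition after reading (controls the number of downstream tasks).
# repartition(n) performs a full shuffle and can raise or lower the partition count;
# to only reduce partitions, coalesce(n) merges existing ones without a full shuffle
df = spark.read.parquet("/data/").repartition(200)

# Bucket the output by a column so later joins on user_id can skip the shuffle
# (typically both tables bucketed on the join key with the same number of buckets);
# bucketBy only works with saveAsTable
df.write \
    .bucketBy(32, "user_id") \
    .sortBy("user_id") \
    .saveAsTable("bucketed_users")
```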
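Bucketing hashes the bucket column into a fixed number of buckets at write time. Equal keys always land in the same bucket, so two tables bucketed the same way on the join key can be joined bucket by bucket without exchanging rows. The plain-Python sketch below illustrates the idea. Spark's own bucketing uses a Murmur3-based hash; zlib.crc32 stands in purely for illustration. bucket_id, bucketize and bucketed_join are hypothetical helpers.

```python
import zlib


def bucket_id(key, num_buckets):
    """Stand-in bucket hash (Spark uses Murmur3); crc32 is for illustration only."""
    return zlib.crc32(str(key).encode("utf-8")) % num_buckets


def bucketize(rows, key, num_buckets):
    """Write-side step: assign every row to a bucket by hashing the bucket column."""
    buckets = {}
    for row in rows:
        buckets.setdefault(bucket_id(row[key], num_buckets), []).append(row)
    return buckets


def bucketed_join(left_rows, right_rows, key, num_buckets=32):
    """Inner join that only pairs bucket b of the left side with bucket b of the right side."""
    left = bucketize(left_rows, key, num_buckets)
    right = bucketize(right_rows, key, num_buckets)
    result = []
    for b, left_bucket in left.items():
        for lrow in left_bucket:
            for rrow in right.get(b, []):
                if lrow[key] == rrow[key]:
                    result.append({**lrow, **rrow})
    return result


users = [{"user_id": i, "name": f"user{i}"} for i in range(100)]
orders = [{"user_id": i % 10, "amount": i} for i in range(50)]
print(len(bucketed_join(users, orders, "user_id")))  # 50
```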