Spark RDD Operations in Scala: Complete Guide

RDDs (Resilient Distributed Datasets) are the foundational abstraction in Apache Spark. While DataFrames are preferred for most production work, understanding RDDs is still essential: DataFrame queries are ultimately executed as RDD operations, and RDDs give you low-level control over data and partitioning. This module covers RDD creation, transformations, actions, and persistence.

What Is an RDD?

An RDD is:

  • Resilient — it can be recomputed if a partition is lost
  • Distributed — data is split across executor nodes
  • Dataset — a collection of elements (any Scala type)

RDDs are immutable and lazily evaluated. Every transformation creates a new RDD without modifying the original.
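
Both properties are easy to observe. The sketch below assumes Spark 3.x running in local mode; `LazyDemo` and the `mapCalls` accumulator are illustrative names. The accumulator counts how often the `map` function runs: it stays at zero after the transformation is defined and only moves once an action executes the lineage, while the source RDD's contents never change.

```scala
import org.apache.spark.sql.SparkSession

object LazyDemo {
  // Returns (map calls before the action, map calls after it, transformed values, original values)
  def run(): (Long, Long, List[Int], List[Int]) = {
    val spark = SparkSession.builder().master("local[*]").appName("LazyDemo").getOrCreate()
    val sc = spark.sparkContext

    val base = sc.parallelize(1 to 5)
    val mapCalls = sc.longAccumulator("mapCalls") // counter readable on the driver

    // Transformation: returns a new RDD and records the step; nothing executes yet
    val doubled = base.map { x => mapCalls.add(1); x * 2 }
    val callsBefore = mapCalls.sum

    // Action: executes the lineage, calling the map function once per element
    val doubledValues = doubled.collect().toList
    val callsAfter = mapCalls.sum

    // The source RDD is unaffected by the transformation
    val baseValues = base.collect().toList

    spark.stop() // standalone use; omit inside spark-shell
    (callsBefore, callsAfter, doubledValues, baseValues)
  }
}

val (callsBefore, callsAfter, doubledValues, baseValues) = LazyDemo.run()
println(s"map calls before the action: $callsBefore, after: $callsAfter")
```

The work sits inside a method of an object rather than at the top level of a script. Spark runs closures on executor threads, and a closure defined in a script's top-level initializer can block on that half-initialized script object, so this layout is a defensive choice rather than a Spark requirement.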

Creating RDDs

scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("RDD").getOrCreate()
val sc = spark.sparkContext

// From a local collection
val rdd1 = sc.parallelize(List(1, 2, 3, 4, 5))
val rdd2 = sc.parallelize(List(1, 2, 3, 4, 5), numSlices = 4)  // with 4 partitions

// From a text file (each line becomes an element)
val lines = sc.textFile("/path/to/file.txt")

// From multiple files (glob)
val multiFiles = sc.textFile("/path/to/data/*.txt")

// From a directory
val dir = sc.textFile("/path/to/directory/")
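
To check that each line really becomes one element, here is a small sketch that writes a temporary file and reads it back. It assumes Spark 3.x in local mode; `TextFileDemo` and the file contents are illustrative. The optional `minPartitions` argument of `textFile` is a suggested minimum number of partitions, not an exact count.

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import org.apache.spark.sql.SparkSession

object TextFileDemo {
  // Returns (number of elements, first element)
  def run(): (Long, String) = {
    val spark = SparkSession.builder().master("local[*]").appName("TextFileDemo").getOrCreate()
    val sc = spark.sparkContext

    // Illustrative input: a temporary three-line file
    val path = Files.createTempFile("rdd-demo", ".txt")
    Files.write(path, "alpha\nbeta\ngamma\n".getBytes(StandardCharsets.UTF_8))

    // Each line becomes one String element; minPartitions is only a suggested minimum
    val lines = sc.textFile(path.toString, minPartitions = 2)
    val result = (lines.count(), lines.first())

    spark.stop() // standalone use; omit inside spark-shell
    Files.deleteIfExists(path)
    result
  }
}

val (lineCount, firstLine) = TextFileDemo.run()
println(s"$lineCount lines, first: $firstLine")
```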

Key Transformations

Transformations are lazy — they define the computation graph but don't execute until an action is called.

map and flatMap

scala
val nums = sc.parallelize(1 to 10)

val doubled = nums.map(_ * 2)
// RDD[Int]: 2, 4, 6, 8, 10, 12, 14, 16, 18, 20

val words = sc.parallelize(List("hello world", "spark rdd"))
val allWords = words.flatMap(_.split(" "))
// RDD[String]: hello, world, spark, rdd
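
The difference between the two is easiest to see by counting elements. This sketch (assuming Spark 3.x in local mode; `MapVsFlatMap` is an illustrative name) applies the same split function both ways: `map` yields one array per sentence, while `flatMap` flattens those arrays into individual words.

```scala
import org.apache.spark.sql.SparkSession

object MapVsFlatMap {
  // Returns (element count after map, element count after flatMap, flattened words)
  def run(): (Long, Long, List[String]) = {
    val spark = SparkSession.builder().master("local[*]").appName("MapVsFlatMap").getOrCreate()
    val sc = spark.sparkContext
    val sentences = sc.parallelize(List("hello world", "spark rdd"))

    // map: exactly one output per input, so this is an RDD[Array[String]] with 2 elements
    val nested = sentences.map(_.split(" "))
    // flatMap: each returned collection is flattened, giving an RDD[String] with 4 elements
    val flat = sentences.flatMap(_.split(" "))

    val result = (nested.count(), flat.count(), flat.collect().toList)
    spark.stop() // standalone use; omit inside spark-shell
    result
  }
}

val (nestedCount, flatCount, words) = MapVsFlatMap.run()
```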

filter

scala
val evens = nums.filter(_ % 2 == 0)
// RDD[Int]: 2, 4, 6, 8, 10

mapPartitions

More efficient than map when you need to initialize a resource once per partition (like a database connection):

scala
val result = nums.mapPartitions { iter =>
  // The outer block runs once per partition: put per-partition setup here.
  // The function passed to iter.map still runs once per element.
  iter.map(_ * 2)
}
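
The example above only doubles numbers, so it does not show the per-partition setup that motivates `mapPartitions`. In this sketch (Spark 3.x, local mode) a `java.text.DecimalFormat`, which is not thread-safe and so is a natural fit for one instance per task, stands in for an expensive resource such as a database connection. `MapPartitionsDemo` and the `setups` accumulator are illustrative; the accumulator counts how often the setup runs.

```scala
import org.apache.spark.sql.SparkSession

object MapPartitionsDemo {
  // Returns (formatted values, number of per-partition setups)
  def run(): (List[String], Long) = {
    val spark = SparkSession.builder().master("local[*]").appName("MapPartitionsDemo").getOrCreate()
    val sc = spark.sparkContext

    val nums = sc.parallelize(1 to 10, numSlices = 3)
    val setups = sc.longAccumulator("setups")

    val padded = nums.mapPartitions { iter =>
      // Per-partition setup: runs once for each of the 3 partitions
      setups.add(1)
      val fmt = new java.text.DecimalFormat("000")
      // Per-element work: iter.map is lazy, so elements are formatted as Spark consumes them
      iter.map(n => fmt.format(n.toLong))
    }

    val result = (padded.collect().toList, setups.sum)
    spark.stop() // standalone use; omit inside spark-shell
    result
  }
}

val (padded, setupCount) = MapPartitionsDemo.run()
```

If the resource has to be closed, remember that `iter.map` is lazy: closing the resource before returning would break the iterator, so either materialize the partition first (for example with `toList`, at the cost of holding it in memory) or close the resource once the iterator is exhausted.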

distinct, union, intersection, subtract

scala
val a = sc.parallelize(List(1, 2, 3, 3, 4))
val b = sc.parallelize(List(3, 4, 5))

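// distinct, intersection and subtract shuffle data, so their output order is not guaranteed; sort before printing
println(a.distinct().collect().sorted.toList)        // List(1, 2, 3, 4)
println(a.union(b).collect().toList)                 // List(1, 2, 3, 3, 4, 3, 4, 5)
println(a.intersection(b).collect().sorted.toList)   // List(3, 4)
println(a.subtract(b).collect().sorted.toList)       // List(1, 2)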
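
These operations treat duplicates differently, which the example above only partly shows. In this sketch (Spark 3.x, local mode; `SetOpsDemo` is illustrative) `a` contains repeated values: `union` keeps all of them, `intersection` deduplicates its output, and `subtract` keeps the duplicates from `a`.

```scala
import org.apache.spark.sql.SparkSession

object SetOpsDemo {
  // Returns (union, distinct, intersection, subtract), each sorted on the driver
  def run(): (List[Int], List[Int], List[Int], List[Int]) = {
    val spark = SparkSession.builder().master("local[*]").appName("SetOpsDemo").getOrCreate()
    val sc = spark.sparkContext

    val a = sc.parallelize(List(1, 1, 2, 3, 3, 4))
    val b = sc.parallelize(List(3, 4, 5))

    // Sorting after collect() makes shuffle-based results comparable
    val result = (
      a.union(b).collect().sorted.toList,        // keeps every duplicate
      a.distinct().collect().sorted.toList,      // removes duplicates
      a.intersection(b).collect().sorted.toList, // deduplicated, even though a has two 3s
      a.subtract(b).collect().sorted.toList      // keeps a's duplicates (both 1s survive)
    )
    spark.stop() // standalone use; omit inside spark-shell
    result
  }
}

val (unionAll, distinctA, common, onlyInA) = SetOpsDemo.run()
```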

Key-Value RDD Transformations (Pair RDDs)

When an RDD contains key-value tuples (K, V), an implicit conversion to PairRDDFunctions makes it a Pair RDD with extra key-based operations:

scala
val pairs = sc.parallelize(List(
  ("alice", 85), ("bob", 72), ("alice", 90), ("bob", 88), ("carol", 95)
))

// reduceByKey — aggregate values for each key
val totals = pairs.reduceByKey(_ + _)
println(totals.collect().sortBy(_._1).toList)  // reduceByKey shuffles, so sort for a stable order
// List((alice,175), (bob,160), (carol,95))

// groupByKey — group all values by key (less efficient than reduceByKey)
val grouped = pairs.groupByKey()
// RDD[(String, Iterable[Int])] with contents like (alice, [85, 90]), (bob, [72, 88]), (carol, [95]); order may vary

// mapValues — transform values without changing keys
val doubled = pairs.mapValues(_ * 2)

// sortByKey
val sorted = pairs.sortByKey()

// countByKey: count occurrences of each key (an action; returns a Map[K, Long] to the driver)
val counts = pairs.countByKey()
println(counts)  // Map(alice -> 2, bob -> 2, carol -> 1); entry order may vary
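
Pair RDDs usually come from mapping plain elements into tuples. The classic word count below is a minimal sketch (Spark 3.x, local mode; `WordCountDemo` and the input sentences are illustrative) that turns words into `(word, 1)` pairs, after which `reduceByKey` becomes available through the implicit conversion to `PairRDDFunctions`.

```scala
import org.apache.spark.sql.SparkSession

object WordCountDemo {
  // Returns word counts sorted by word
  def run(): List[(String, Int)] = {
    val spark = SparkSession.builder().master("local[*]").appName("WordCountDemo").getOrCreate()
    val sc = spark.sparkContext
    val text = sc.parallelize(List("to be or", "not to be"))

    val counts = text
      .flatMap(_.split(" "))
      .map(word => (word, 1)) // an RDD[(String, Int)]: pair operations are now available
      .reduceByKey(_ + _)     // sums the 1s for each word

    val result = counts.collect().sortBy(_._1).toList // reduceByKey output order is not guaranteed
    spark.stop() // standalone use; omit inside spark-shell
    result
  }
}

val wordCounts = WordCountDemo.run()
// List((be,2), (not,1), (or,1), (to,2))
```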

join Operations

scala
val names = sc.parallelize(List((1, "Alice"), (2, "Bob"), (3, "Carol")))
val scores = sc.parallelize(List((1, 95), (2, 87), (4, 78)))

// Inner join
val joined = names.join(scores)
println(joined.collect().sortBy(_._1).toList)  // joins shuffle, so sort for a stable order
// List((1,(Alice,95)), (2,(Bob,87)))

// Left outer join
val leftJoined = names.leftOuterJoin(scores)
// List((1,(Alice,Some(95))), (2,(Bob,Some(87))), (3,(Carol,None)))
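
The remaining outer joins follow the same pattern. This sketch (Spark 3.x, local mode; `JoinDemo` is an illustrative name) reuses the `names` and `scores` data: `rightOuterJoin` keeps every key from `scores`, `fullOuterJoin` keeps keys from both sides, and `getOrElse` turns a missing score from a left outer join into a default. Results are sorted on the driver because joins shuffle data.

```scala
import org.apache.spark.sql.SparkSession

object JoinDemo {
  final case class Result(
    right: List[(Int, (Option[String], Int))],
    full: List[(Int, (Option[String], Option[Int]))],
    withDefault: List[(Int, (String, Int))]
  )

  def run(): Result = {
    val spark = SparkSession.builder().master("local[*]").appName("JoinDemo").getOrCreate()
    val sc = spark.sparkContext
    val names = sc.parallelize(List((1, "Alice"), (2, "Bob"), (3, "Carol")))
    val scores = sc.parallelize(List((1, 95), (2, 87), (4, 78)))

    val result = Result(
      // Keeps every key from scores; a missing name becomes None
      right = names.rightOuterJoin(scores).collect().sortBy(_._1).toList,
      // Keeps keys from both sides; either value may be None
      full = names.fullOuterJoin(scores).collect().sortBy(_._1).toList,
      // Left outer join, then replace a missing score with 0
      withDefault = names.leftOuterJoin(scores)
        .mapValues { case (name, score) => (name, score.getOrElse(0)) }
        .collect().sortBy(_._1).toList
    )
    spark.stop() // standalone use; omit inside spark-shell
    result
  }
}

val joins = JoinDemo.run()
```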

Actions

Actions trigger execution. Most return a result to the driver; others, such as saveAsTextFile and foreach, produce side effects instead:

scala
val rdd = sc.parallelize(1 to 100)

println(rdd.count())                 // 100
println(rdd.sum())                   // 5050.0
println(rdd.min())                   // 1
println(rdd.max())                   // 100
println(rdd.mean())                  // 50.5
println(rdd.first())                 // 1
println(rdd.take(5).toList)          // List(1, 2, 3, 4, 5)
println(rdd.takeSample(false, 5).toList)  // 5 random elements (toList avoids printing an Array reference)
println(rdd.collect().length)        // 100 (brings all data to driver — use carefully!)

// saveAsTextFile: writes one part-file per partition into a new directory (fails if the path already exists)
rdd.saveAsTextFile("/tmp/output")

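// foreach: run a side-effecting function on each element; it runs on the executors,
// so in cluster mode this output goes to executor logs, not the driver console
rdd.foreach(println)

// To print a few elements on the driver, bring them back first
rdd.take(5).foreach(println)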
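
Beyond the aggregates above, `reduce`, `fold`, `aggregate`, `top`, and `takeOrdered` are common actions. The sketch below assumes Spark 3.x in local mode (`ActionsDemo` is illustrative). Unlike `sum()`, which returns a `Double`, `reduce` keeps the element type, and `aggregate` can build a `(sum, count)` pair in one pass to derive the mean.

```scala
import org.apache.spark.sql.SparkSession

object ActionsDemo {
  final case class Result(total: Int, folded: Int, largest: List[Int], smallest: List[Int], mean: Double)

  def run(): Result = {
    val spark = SparkSession.builder().master("local[*]").appName("ActionsDemo").getOrCreate()
    val sc = spark.sparkContext
    val rdd = sc.parallelize(1 to 100)

    // aggregate computes (sum, count) in a single pass over the data
    val (sum, count) = rdd.aggregate((0, 0))(
      (acc, x) => (acc._1 + x, acc._2 + 1), // folds an element into a partition-local accumulator
      (l, r) => (l._1 + r._1, l._2 + r._2)  // merges accumulators from different partitions
    )

    val result = Result(
      total = rdd.reduce(_ + _),            // 5050 as an Int
      folded = rdd.fold(0)(_ + _),          // like reduce, with a neutral zero value per partition
      largest = rdd.top(3).toList,          // List(100, 99, 98)
      smallest = rdd.takeOrdered(3).toList, // List(1, 2, 3)
      mean = sum.toDouble / count           // 50.5
    )
    spark.stop() // standalone use; omit inside spark-shell
    result
  }
}

val actions = ActionsDemo.run()
```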

RDD Persistence

By default, Spark recomputes an RDD from its lineage every time it is used in an action. Use cache() to keep the computed partitions in memory, or persist() to choose a storage level that can also use disk:

scala
import org.apache.spark.storage.StorageLevel

val expensive = sc.textFile("/large/file.txt")
  .flatMap(_.split(" "))
  .filter(_.length > 4)

// Cache in memory (equivalent to persist(StorageLevel.MEMORY_ONLY))
expensive.cache()

// First action — computes and caches
val count = expensive.count()

// Second action — uses cached data, no recomputation
val sample = expensive.take(10)

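// Other storage levels. An RDD's storage level can only be assigned once (trying to change it
// throws an exception), so call expensive.unpersist() before picking another one:
// expensive.persist(StorageLevel.MEMORY_AND_DISK)  // spill to disk if memory is full
// expensive.persist(StorageLevel.DISK_ONLY)        // always on disk
// expensive.persist(StorageLevel.MEMORY_ONLY_SER)  // serialized: less memory, more CPU to read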

// Free the cache when done
expensive.unpersist()
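
An accumulator makes the effect of caching visible. In this sketch (Spark 3.x, local mode; `CacheDemo` and the accumulator names are illustrative), two actions on an uncached RDD run its `map` twice per element, while the cached RDD runs it once. It also shows the `unpersist()`-then-`persist()` sequence needed to change an RDD's storage level. Accumulators updated inside transformations can over-count when tasks are retried, so treat them as a diagnostic rather than a metric.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

object CacheDemo {
  final case class Result(
    evalsWithoutCache: Long,
    evalsWithCache: Long,
    levelAfterCache: StorageLevel,
    levelAfterChange: StorageLevel
  )

  def run(): Result = {
    val spark = SparkSession.builder().master("local[*]").appName("CacheDemo").getOrCreate()
    val sc = spark.sparkContext

    // Without caching: every action recomputes the map
    val plainCalls = sc.longAccumulator("plainCalls")
    val plain = sc.parallelize(1 to 10, 2).map { x => plainCalls.add(1); x * 2 }
    plain.count()
    plain.collect()

    // With caching: the first action computes and stores the partitions, the second reads them
    val cachedCalls = sc.longAccumulator("cachedCalls")
    val cached = sc.parallelize(1 to 10, 2).map { x => cachedCalls.add(1); x * 2 }.cache()
    cached.count()
    cached.collect()
    val levelAfterCache = cached.getStorageLevel

    // unpersist() resets the storage level, so a different one can be assigned
    cached.unpersist()
    cached.persist(StorageLevel.MEMORY_AND_DISK)
    val levelAfterChange = cached.getStorageLevel

    val result = Result(plainCalls.sum, cachedCalls.sum, levelAfterCache, levelAfterChange)
    spark.stop() // standalone use; omit inside spark-shell
    result
  }
}

val cacheResult = CacheDemo.run()
```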

When to Cache

Cache an RDD when:

  • It's used multiple times in your application
  • It's expensive to compute (reading from S3, complex transformations)
  • It fits in memory, or you choose a level such as MEMORY_AND_DISK that can spill to disk

Don't cache when the RDD is only used once.

Partitions and Repartitioning

scala
val rdd = sc.parallelize(1 to 100, 8)  // 8 partitions
println(rdd.getNumPartitions)  // 8

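// repartition always performs a full shuffle; use it to increase partitions or rebalance data
val more = rdd.repartition(16)

// coalesce merges existing partitions without a full shuffle, so it is the cheaper way to decrease them
val fewer = rdd.coalesce(4)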
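
`glom()` makes the partition layout visible, and a few calls show how `coalesce` and `repartition` differ. This sketch assumes Spark 3.x in local mode (`PartitionDemo` is illustrative): without `shuffle = true`, `coalesce` can only merge partitions, so asking it for more leaves the count unchanged, while `repartition(n)` behaves like `coalesce(n, shuffle = true)`.

```scala
import org.apache.spark.sql.SparkSession

object PartitionDemo {
  final case class Result(
    layout: List[List[Int]],
    coalesced: Int,
    coalesceUp: Int,
    coalesceUpShuffled: Int,
    repartitioned: Int
  )

  def run(): Result = {
    val spark = SparkSession.builder().master("local[*]").appName("PartitionDemo").getOrCreate()
    val sc = spark.sparkContext
    val rdd = sc.parallelize(1 to 8, 4)

    val result = Result(
      // glom() turns each partition into an array, exposing how elements are laid out
      layout = rdd.glom().collect().map(_.toList).toList,
      coalesced = rdd.coalesce(2).getNumPartitions,                          // merged, no full shuffle
      coalesceUp = rdd.coalesce(8).getNumPartitions,                         // stays 4 without a shuffle
      coalesceUpShuffled = rdd.coalesce(8, shuffle = true).getNumPartitions, // 8
      repartitioned = rdd.repartition(8).getNumPartitions                    // 8
    )
    spark.stop() // standalone use; omit inside spark-shell
    result
  }
}

val partitions = PartitionDemo.run()
```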

Frequently Asked Questions

Q: When should I use RDDs instead of DataFrames? Use DataFrames for structured or semi-structured data: they are usually faster (Spark's Catalyst optimizer can plan SQL-style queries efficiently) and easier to work with. Use RDDs when you need to process unstructured data (raw text, binary), when you need fine-grained control over partitioning, when your elements are arbitrary JVM objects that have no Spark SQL encoder (they still need to be Java- or Kryo-serializable to be shuffled), or when you need low-level Spark features not exposed in the DataFrame API.
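
Moving between the two APIs takes only a line each way, so you can drop to RDDs just where needed. This sketch assumes Spark 3.x with spark-sql on the classpath (`RddDataFrameDemo` is illustrative): `toDF` comes from `spark.implicits._`, and `df.rdd` returns an `RDD[Row]`.

```scala
import org.apache.spark.sql.SparkSession

object RddDataFrameDemo {
  final case class Result(columns: List[String], highScorers: Long, roundTrip: List[(String, Int)])

  def run(): Result = {
    val spark = SparkSession.builder().master("local[*]").appName("RddDataFrameDemo").getOrCreate()
    import spark.implicits._ // provides toDF for RDDs of tuples and case classes

    val pairs = spark.sparkContext.parallelize(List(("alice", 85), ("bob", 72)))

    // RDD -> DataFrame, with column names supplied explicitly
    val df = pairs.toDF("name", "score")
    // A DataFrame operation, planned by the Catalyst optimizer
    val highScorers = df.filter("score > 80").count()
    // DataFrame -> RDD[Row] -> plain tuples again
    val roundTrip = df.rdd.map(row => (row.getString(0), row.getInt(1))).collect().sortBy(_._1).toList

    val result = Result(df.columns.toList, highScorers, roundTrip)
    spark.stop() // standalone use; omit inside spark-shell
    result
  }
}

val conversion = RddDataFrameDemo.run()
```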

Q: Why is groupByKey less efficient than reduceByKey? reduceByKey performs a partial aggregation on each partition before shuffling (a map-side combine), so at most one partially aggregated value per key per partition crosses the network. groupByKey shuffles every raw value, moving all values for a key into a single partition before grouping; this causes large data transfers and can trigger OOM errors when a single key has very many values (skewed keys). Always prefer reduceByKey, aggregateByKey, or combineByKey over groupByKey when aggregating.
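
The two approaches produce the same totals; the difference is how much data is shuffled. This sketch (Spark 3.x, local mode; `ByKeyDemo` is illustrative) reuses the score data from earlier and also uses `aggregateByKey` to compute a per-key average, which needs a `(sum, count)` accumulator rather than a plain sum.

```scala
import org.apache.spark.sql.SparkSession

object ByKeyDemo {
  final case class Result(
    viaGroup: List[(String, Int)],
    viaReduce: List[(String, Int)],
    averages: List[(String, Double)]
  )

  def run(): Result = {
    val spark = SparkSession.builder().master("local[*]").appName("ByKeyDemo").getOrCreate()
    val sc = spark.sparkContext
    val pairs = sc.parallelize(List(("alice", 85), ("bob", 72), ("alice", 90), ("bob", 88), ("carol", 95)))

    // Same totals, different shuffle cost
    val viaGroup = pairs.groupByKey().mapValues(_.sum) // every value crosses the network
    val viaReduce = pairs.reduceByKey(_ + _)           // values are combined per partition first

    // aggregateByKey: per-key (sum, count) with map-side combining, then the average
    val averages = pairs
      .aggregateByKey((0, 0))(
        (acc, score) => (acc._1 + score, acc._2 + 1), // adds one value to a partition-local accumulator
        (l, r) => (l._1 + r._1, l._2 + r._2)          // merges accumulators across partitions
      )
      .mapValues { case (sum, count) => sum.toDouble / count }

    val result = Result(
      viaGroup.collect().sortBy(_._1).toList,
      viaReduce.collect().sortBy(_._1).toList,
      averages.collect().sortBy(_._1).toList
    )
    spark.stop() // standalone use; omit inside spark-shell
    result
  }
}

val byKey = ByKeyDemo.run()
```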

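Q: What does "lazy evaluation" mean in practice for Spark RDDs? When you call map, filter, or flatMap, nothing actually runs; Spark just records the operation in the RDD's lineage, a DAG (Directed Acyclic Graph). Only when you call an action like count() or collect() does Spark execute the DAG. Seeing the whole pipeline lets the scheduler plan execution: consecutive narrow transformations (map, filter, flatMap) are pipelined into a single stage that processes each partition in one pass, and stages are split only at shuffle boundaries such as reduceByKey or join. Unlike the Catalyst optimizer for DataFrames, Spark does not rewrite or reorder RDD functions (it cannot push a filter ahead of a map, for example), because your closures are opaque to it.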
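
A practical consequence is that errors in a transformation surface only when an action runs. In this sketch (Spark 3.x, local mode; `LazyFailureDemo` is illustrative) defining `map(_.toInt)` over a bad record succeeds, and the `NumberFormatException` appears only at `collect()`, wrapped in a `SparkException`.

```scala
import scala.util.Try
import org.apache.spark.sql.SparkSession

object LazyFailureDemo {
  // Returns (whether the action failed, result of a cleaned-up pipeline)
  def run(): (Boolean, List[Int]) = {
    val spark = SparkSession.builder().master("local[*]").appName("LazyFailureDemo").getOrCreate()
    val sc = spark.sparkContext
    val raw = sc.parallelize(List("1", "2", "oops"))

    // Defining the transformation succeeds even though "oops".toInt would throw
    val parsed = raw.map(_.toInt)

    // The failure only happens when an action executes the DAG; Try captures the SparkException
    val failed = Try(parsed.collect()).isFailure

    // filter and map are narrow, so Spark pipelines them: each partition is processed in one pass
    val clean = raw.filter(s => s.nonEmpty && s.forall(_.isDigit)).map(_.toInt).collect().toList

    spark.stop() // standalone use; omit inside spark-shell
    (failed, clean)
  }
}

val (actionFailed, cleanValues) = LazyFailureDemo.run()
```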


Part of Scala Mastery Course — Module 16 of 22.