data-processing/hm-spark/applications/find-taxi-top-routes-sql/src/main.py
from pyspark.sql import SparkSession
from utils.trip import load_trips, preprocess_trips
from utils.zone import load_zones, preprocess_zones
def main(
    data_dirname: str,
    trip_filenames: list[str],
    zone_filename: str,
) -> None:
    """Load taxi trips and zones, then print the most frequent routes.

    A route is a (``pulocationid``, ``dolocationid``) pair. Trips are counted
    per route with Spark SQL. Both ends of each route are enriched with the
    ``zone`` and ``borough`` columns from the zones data via inner joins on
    ``locationid``. The result is shown sorted by trip count, highest first.

    The Spark session is always stopped, even if a step fails.

    Args:
        data_dirname: Directory prefix containing the input files.
        trip_filenames: Trip data file names inside ``data_dirname``.
        zone_filename: Zone data file name inside ``data_dirname``.
    """
    # Plain string joining (rather than pathlib) so URI-style prefixes
    # (e.g. "scheme://bucket/dir") are not mangled.
    trip_data_paths = [f"{data_dirname}/{f}" for f in trip_filenames]
    zone_data_path = f"{data_dirname}/{zone_filename}"

    spark = SparkSession.builder.getOrCreate()
    try:
        trips = load_trips(spark, trip_data_paths)
        zones = load_zones(spark, zone_data_path)

        trips = preprocess_trips(trips)
        # (row count, column count) for a quick sanity check of the input.
        print((trips.count(), len(trips.columns)))
        trips.show()

        zones = preprocess_zones(zones)
        print((zones.count(), len(zones.columns)))
        zones.show()

        # Get top routes
        trips.createOrReplaceTempView("trips")
        zones.createOrReplaceTempView("zones")
        # Count trips per route once, then join the zones view twice:
        # once for the pickup end, once for the drop-off end. Inner joins
        # drop routes whose location IDs have no matching zone row.
        top_routes = spark.sql(
            """
            with route_counts as (
                select
                    pulocationid,
                    dolocationid,
                    count(*) as total
                from trips
                group by pulocationid, dolocationid
            )
            select
                rc.pulocationid,
                pickup_zones.zone as pulocation_zone,
                pickup_zones.borough as pulocation_borough,
                rc.dolocationid,
                dropoff_zones.zone as dolocation_zone,
                dropoff_zones.borough as dolocation_borough,
                rc.total
            from route_counts as rc
            inner join zones as pickup_zones
                on rc.pulocationid = pickup_zones.locationid
            inner join zones as dropoff_zones
                on rc.dolocationid = dropoff_zones.locationid
            order by rc.total desc
            """
        )
        print((top_routes.count(), len(top_routes.columns)))
        top_routes.show(truncate=False)
    finally:
        # Release the session on both success and failure.
        spark.stop()
if __name__ == "__main__":
    # https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page
    # Inputs: yellow taxi trip files for January through June 2022,
    # plus the taxi zone file, all read from the local "data" directory.
    external_data_dirname = "data"
    external_trip_filenames = [
        f"yellow_tripdata_2022-{month:02d}.parquet" for month in range(1, 7)
    ]
    external_zone_filename = "taxi_zones.csv"
    main(
        data_dirname=external_data_dirname,
        trip_filenames=external_trip_filenames,
        zone_filename=external_zone_filename,
    )